blob: fa92b05314a2454cca7414f633da90f3fc88daad [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperband diagonal using MNIST\n",
"\n",
"Implementation of Hyperband https://arxiv.org/pdf/1603.06560.pdf for MPP with a synchronous barrier. It uses the Hyperband schedule but runs it on a diagonal across brackets, instead of one bracket at a time, to make more efficient use of cluster resources. \n",
"\n",
"Model architecture based on https://keras.io/examples/mnist_transfer_cnn/ \n",
"\n",
"To load images into tables, we use the script <em>madlib_image_loader.py</em> located at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning which uses the Python Imaging Library and therefore supports multiple image formats; see http://www.pythonware.com/products/pil/\n",
"\n",
"## Table of contents\n",
"<a href=\"#import_libraries\">1. Import libraries</a>\n",
"\n",
"<a href=\"#load_and_prepare_data\">2. Load and prepare data</a>\n",
"\n",
"<a href=\"#image_preproc\">3. Call image preprocessor</a>\n",
"\n",
"<a href=\"#define_and_load_model\">4. Define and load model architecture</a>\n",
"\n",
"<a href=\"#hyperband\">5. Hyperband diagonal</a>\n",
"\n",
"<a href=\"#plot\">6. Plot results</a>\n",
"\n",
"<a href=\"#print\">7. Print run schedules (display only)</a>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
" \"You should import from traitlets.config instead.\", ShimWarning)\n",
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
"#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
"\n",
"# Greenplum Database 5.x on GCP - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib\n",
"\n",
"# psycopg2 connection\n",
"import psycopg2 as p2\n",
"#conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
"conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')\n",
"cur = conn.cursor()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.17-dev, git revision: rel/v1.16-50-g5abfb79, cmake configuration time: Tue Nov 26 01:00:01 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-50-g5abfb79, cmake configuration time: Tue Nov 26 01:00:01 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"import_libraries\"></a>\n",
"# 1. Import libraries\n",
"Import libraries and define some parameters, following https://keras.io/examples/mnist_transfer_cnn/"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n"
]
}
],
"source": [
"from __future__ import print_function\n",
"\n",
"import datetime\n",
"import keras\n",
"from keras.datasets import mnist\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Activation, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras import backend as K\n",
"\n",
"now = datetime.datetime.now\n",
"\n",
"#batch_size = 128\n",
"num_classes = 10\n",
"#epochs = 5\n",
"\n",
"# input image dimensions\n",
"img_rows, img_cols = 28, 28\n",
"# number of convolutional filters to use\n",
"filters = 32\n",
"# size of pooling area for max pooling\n",
"pool_size = 2\n",
"# convolution kernel size\n",
"kernel_size = 3\n",
"\n",
"if K.image_data_format() == 'channels_first':\n",
" input_shape = (1, img_rows, img_cols)\n",
"else:\n",
" input_shape = (img_rows, img_cols, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other libraries needed in this notebook"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load_and_prepare_data\"></a>\n",
"# 2. Load and prepare data\n",
"\n",
"First, load the MNIST data from Keras: a training set of 60,000 28x28 grayscale images of the 10 digits, along with a test set of 10,000 images."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10000, 28, 28)\n",
"(10000, 28, 28, 1)\n"
]
}
],
"source": [
"# the data, split between train and test sets\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"\n",
"# reshape to match model architecture\n",
"print(x_test.shape)\n",
"x_train = x_train.reshape(len(x_train), *input_shape)\n",
"x_test = x_test.reshape(len(x_test), *input_shape)\n",
"print(x_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the datasets into tables using the image loader script <em>madlib_image_loader.py</em> located at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# MADlib tools directory\n",
"import sys\n",
"import os\n",
"madlib_site_dir = '/Users/fmcquillan/Documents/Product/MADlib/Demos/data'\n",
"sys.path.append(madlib_site_dir)\n",
"\n",
"# Import image loader module\n",
"from madlib_image_loader import ImageLoader, DbCredentials"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# Specify database credentials, for connecting to db\n",
"db_creds = DbCredentials(user='gpadmin',\n",
" host='localhost',\n",
" port='8000',\n",
" password='')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# Initialize ImageLoader (increase num_workers to run faster)\n",
"iloader = ImageLoader(num_workers=5, db_creds=db_creds)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"MainProcess: Connected to madlib db.\n",
"Executing: CREATE TABLE train_mnist (id SERIAL, x REAL[], y TEXT)\n",
"CREATE TABLE\n",
"Created table train_mnist in madlib db\n",
"Spawning 5 workers...\n",
"Initializing PoolWorker-11 [pid 34068]\n",
"PoolWorker-11: Created temporary directory /tmp/madlib_RbuQlbqxI5\n",
"Initializing PoolWorker-12 [pid 34069]\n",
"PoolWorker-12: Created temporary directory /tmp/madlib_tEyH9GMFGV\n",
"Initializing PoolWorker-13 [pid 34070]\n",
"PoolWorker-13: Created temporary directory /tmp/madlib_TyYs4viAVD\n",
"Initializing PoolWorker-14 [pid 34071]\n",
"Initializing PoolWorker-15 [pid 34072]\n",
"PoolWorker-14: Created temporary directory /tmp/madlib_KTwnncRsaq\n",
"PoolWorker-15: Created temporary directory /tmp/madlib_jtG9zAC8HU\n",
"PoolWorker-11: Connected to madlib db.\n",
"PoolWorker-13: Connected to madlib db.\n",
"PoolWorker-14: Connected to madlib db.\n",
"PoolWorker-12: Connected to madlib db.\n",
"PoolWorker-15: Connected to madlib db.\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0000.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0000.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0000.tmp\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0000.tmp\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0000.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0001.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0001.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0001.tmp\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0001.tmp\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0001.tmp\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0002.tmp\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0002.tmp\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0002.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0002.tmp\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0002.tmp\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0003.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0003.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0003.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0003.tmp\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0003.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0004.tmp\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0004.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0004.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0004.tmp\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0004.tmp\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0005.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0005.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0005.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0005.tmp\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0005.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0006.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0006.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0006.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0006.tmp\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0006.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0007.tmp\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0007.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0007.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0007.tmp\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0007.tmp\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0008.tmp\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0008.tmp\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0008.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0008.tmp\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0008.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0009.tmp\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0009.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0009.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0009.tmp\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0009.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0010.tmp\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0010.tmp\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0010.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0010.tmp\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0010.tmp\n",
"PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0011.tmp\n",
"PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0011.tmp\n",
"PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0011.tmp\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0011.tmp\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-13: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Loaded 1000 images into train_mnist\n",
"PoolWorker-14: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0011.tmp\n",
"PoolWorker-12: Loaded 1000 images into train_mnist\n",
"PoolWorker-15: Loaded 1000 images into train_mnist\n",
"PoolWorker-11: Removed temporary directory /tmp/madlib_RbuQlbqxI5\n",
"PoolWorker-15: Removed temporary directory /tmp/madlib_jtG9zAC8HU\n",
"PoolWorker-12: Removed temporary directory /tmp/madlib_tEyH9GMFGV\n",
"PoolWorker-13: Removed temporary directory /tmp/madlib_TyYs4viAVD\n",
"PoolWorker-14: Removed temporary directory /tmp/madlib_KTwnncRsaq\n",
"Done! Loaded 60000 images in 45.7068669796s\n",
"5 workers terminated.\n",
"MainProcess: Connected to madlib db.\n",
"Executing: CREATE TABLE test_mnist (id SERIAL, x REAL[], y TEXT)\n",
"CREATE TABLE\n",
"Created table test_mnist in madlib db\n",
"Spawning 5 workers...\n",
"Initializing PoolWorker-16 [pid 34074]\n",
"PoolWorker-16: Created temporary directory /tmp/madlib_MjwU1yRoMW\n",
"Initializing PoolWorker-17 [pid 34075]\n",
"PoolWorker-17: Created temporary directory /tmp/madlib_kTezv88uWu\n",
"Initializing PoolWorker-18 [pid 34076]\n",
"PoolWorker-18: Created temporary directory /tmp/madlib_TFIofbewK1\n",
"Initializing PoolWorker-19 [pid 34077]\n",
"PoolWorker-19: Created temporary directory /tmp/madlib_QUIRxlckvj\n",
"PoolWorker-20: Created temporary directory /tmp/madlib_Eii5YFUzCZ\n",
"Initializing PoolWorker-20 [pid 34078]\n",
"PoolWorker-17: Connected to madlib db.\n",
"PoolWorker-18: Connected to madlib db.\n",
"PoolWorker-19: Connected to madlib db.\n",
"PoolWorker-16: Connected to madlib db.\n",
"PoolWorker-20: Connected to madlib db.\n",
"PoolWorker-18: Wrote 1000 images to /tmp/madlib_TFIofbewK1/test_mnist0000.tmp\n",
"PoolWorker-19: Wrote 1000 images to /tmp/madlib_QUIRxlckvj/test_mnist0000.tmp\n",
"PoolWorker-17: Wrote 1000 images to /tmp/madlib_kTezv88uWu/test_mnist0000.tmp\n",
"PoolWorker-16: Wrote 1000 images to /tmp/madlib_MjwU1yRoMW/test_mnist0000.tmp\n",
"PoolWorker-20: Wrote 1000 images to /tmp/madlib_Eii5YFUzCZ/test_mnist0000.tmp\n",
"PoolWorker-18: Loaded 1000 images into test_mnist\n",
"PoolWorker-17: Loaded 1000 images into test_mnist\n",
"PoolWorker-19: Loaded 1000 images into test_mnist\n",
"PoolWorker-18: Wrote 1000 images to /tmp/madlib_TFIofbewK1/test_mnist0001.tmp\n",
"PoolWorker-16: Loaded 1000 images into test_mnist\n",
"PoolWorker-20: Loaded 1000 images into test_mnist\n",
"PoolWorker-18: Loaded 1000 images into test_mnist\n",
"PoolWorker-19: Wrote 1000 images to /tmp/madlib_QUIRxlckvj/test_mnist0001.tmp\n",
"PoolWorker-17: Wrote 1000 images to /tmp/madlib_kTezv88uWu/test_mnist0001.tmp\n",
"PoolWorker-16: Wrote 1000 images to /tmp/madlib_MjwU1yRoMW/test_mnist0001.tmp\n",
"PoolWorker-20: Wrote 1000 images to /tmp/madlib_Eii5YFUzCZ/test_mnist0001.tmp\n",
"PoolWorker-19: Loaded 1000 images into test_mnist\n",
"PoolWorker-17: Loaded 1000 images into test_mnist\n",
"PoolWorker-16: Loaded 1000 images into test_mnist\n",
"PoolWorker-20: Loaded 1000 images into test_mnist\n",
"PoolWorker-16: Removed temporary directory /tmp/madlib_MjwU1yRoMW\n",
"PoolWorker-19: Removed temporary directory /tmp/madlib_QUIRxlckvj\n",
"PoolWorker-17: Removed temporary directory /tmp/madlib_kTezv88uWu\n",
"PoolWorker-20: Removed temporary directory /tmp/madlib_Eii5YFUzCZ\n",
"PoolWorker-18: Removed temporary directory /tmp/madlib_TFIofbewK1\n",
"Done! Loaded 10000 images in 6.80017995834s\n",
"5 workers terminated.\n"
]
}
],
"source": [
"# Drop tables\n",
"%sql DROP TABLE IF EXISTS train_mnist, test_mnist\n",
"\n",
"# Save images to temporary directories and load into database\n",
"iloader.load_dataset_from_np(x_train, y_train, 'train_mnist', append=False)\n",
"iloader.load_dataset_from_np(x_test, y_test, 'test_mnist', append=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"image_preproc\"></a>\n",
"# 3. Call image preprocessor\n",
"\n",
"The preprocessor transforms the data from one image per row to multiple images per row for batch optimization. It also normalizes the image data and one-hot encodes the class labels.\n",
"\n",
"Training dataset"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" </tr>\n",
" <tr>\n",
" <td>train_mnist</td>\n",
" <td>train_mnist_packed</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>text</td>\n",
" <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
" <td>1000</td>\n",
" <td>255.0</td>\n",
" <td>10</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'train_mnist', u'train_mnist_packed', u'y', u'x', u'text', [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], 1000, 255.0, 10)]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS train_mnist_packed, train_mnist_packed_summary;\n",
"\n",
"SELECT madlib.training_preprocessor_dl('train_mnist', -- Source table\n",
" 'train_mnist_packed', -- Output table\n",
" 'y', -- Dependent variable\n",
" 'x', -- Independent variable\n",
" 1000, -- Buffer size\n",
" 255 -- Normalizing constant\n",
" );\n",
"\n",
"SELECT * FROM train_mnist_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test dataset"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" </tr>\n",
" <tr>\n",
" <td>test_mnist</td>\n",
" <td>test_mnist_packed</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>text</td>\n",
" <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
" <td>5000</td>\n",
" <td>255.0</td>\n",
" <td>10</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'test_mnist', u'test_mnist_packed', u'y', u'x', u'text', [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], 5000, 255.0, 10)]"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS test_mnist_packed, test_mnist_packed_summary;\n",
"\n",
"SELECT madlib.validation_preprocessor_dl('test_mnist', -- Source table\n",
" 'test_mnist_packed', -- Output table\n",
" 'y', -- Dependent variable\n",
" 'x', -- Independent variable\n",
" 'train_mnist_packed' -- Training preproc table\n",
" );\n",
"\n",
"SELECT * FROM test_mnist_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"define_and_load_model\"></a>\n",
"# 4. Define and load model architecture\n",
"\n",
"Model with feature and classification layers trainable"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_1 (Conv2D) (None, 26, 26, 32) 320 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (None, 26, 26, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 24, 24, 32) 9248 \n",
"_________________________________________________________________\n",
"activation_2 (Activation) (None, 24, 24, 32) 0 \n",
"_________________________________________________________________\n",
"max_pooling2d_1 (MaxPooling2 (None, 12, 12, 32) 0 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 12, 12, 32) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 4608) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 128) 589952 \n",
"_________________________________________________________________\n",
"activation_3 (Activation) (None, 128) 0 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 1290 \n",
"_________________________________________________________________\n",
"activation_4 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 600,810\n",
"Trainable params: 600,810\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# define two groups of layers: feature (convolutions) and classification (dense)\n",
"feature_layers = [\n",
" Conv2D(filters, kernel_size,\n",
" padding='valid',\n",
" input_shape=input_shape),\n",
" Activation('relu'),\n",
" Conv2D(filters, kernel_size),\n",
" Activation('relu'),\n",
" MaxPooling2D(pool_size=pool_size),\n",
" Dropout(0.25),\n",
" Flatten(),\n",
"]\n",
"\n",
"classification_layers = [\n",
" Dense(128),\n",
" Activation('relu'),\n",
" Dropout(0.5),\n",
" Dense(num_classes),\n",
" Activation('softmax')\n",
"]\n",
"\n",
"# create complete model\n",
"model = Sequential(feature_layers + classification_layers)\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load into model architecture table using psycopg2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>name</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>feature + classification layers trainable</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'feature + classification layers trainable')]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql DROP TABLE IF EXISTS model_arch_table_mnist;\n",
"query = \"SELECT madlib.load_keras_model('model_arch_table_mnist', %s, NULL, %s)\"\n",
"cur.execute(query,[model.to_json(), \"feature + classification layers trainable\"])\n",
"conn.commit()\n",
"\n",
"# check model loaded OK\n",
"%sql SELECT model_id, name FROM model_arch_table_mnist;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"hyperband\"></a>\n",
"# 5. Hyperband diagonal"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create tables"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- overall results table\n",
"DROP TABLE IF EXISTS results_mnist;\n",
"CREATE TABLE results_mnist ( \n",
" mst_key INTEGER, -- note not SERIAL\n",
" model_id INTEGER, \n",
" compile_params TEXT,\n",
" fit_params TEXT, \n",
" model_type TEXT, \n",
" model_size DOUBLE PRECISION, \n",
" metrics_elapsed_time DOUBLE PRECISION[], \n",
" metrics_type TEXT[], \n",
" training_metrics_final DOUBLE PRECISION, \n",
" training_loss_final DOUBLE PRECISION, \n",
" training_metrics DOUBLE PRECISION[], \n",
" training_loss DOUBLE PRECISION[], \n",
" validation_metrics_final DOUBLE PRECISION, \n",
" validation_loss_final DOUBLE PRECISION, \n",
" validation_metrics DOUBLE PRECISION[], \n",
" validation_loss DOUBLE PRECISION[], \n",
" model_arch_table TEXT, \n",
" num_iterations INTEGER, \n",
" start_training_time TIMESTAMP, \n",
" end_training_time TIMESTAMP,\n",
" s INTEGER, \n",
" i INTEGER,\n",
" run_id SERIAL\n",
" );\n",
"\n",
"-- model selection table\n",
"DROP TABLE IF EXISTS mst_table_hb_mnist;\n",
"CREATE TABLE mst_table_hb_mnist (\n",
" mst_key SERIAL, \n",
" s INTEGER, -- bracket\n",
" model_id INTEGER, \n",
" compile_params VARCHAR, \n",
" fit_params VARCHAR\n",
" );\n",
"\n",
"-- model selection summary table\n",
"DROP TABLE IF EXISTS mst_table_hb_mnist_summary;\n",
"CREATE TABLE mst_table_hb_mnist_summary (model_arch_table VARCHAR);\n",
"INSERT INTO mst_table_hb_mnist_summary VALUES ('model_arch_table_mnist');\n",
"\n",
"-- model selection table for diagonal\n",
"DROP TABLE IF EXISTS mst_diag_table_hb_mnist;\n",
"CREATE TABLE mst_diag_table_hb_mnist (\n",
" mst_key INTEGER, -- note not SERIAL since this table derived from main model selection table\n",
" s INTEGER, -- bracket\n",
" model_id INTEGER, \n",
" compile_params VARCHAR, \n",
" fit_params VARCHAR\n",
" );\n",
"\n",
"-- model selection summary table for diagonal\n",
"DROP TABLE IF EXISTS mst_diag_table_hb_mnist_summary;\n",
"CREATE TABLE mst_diag_table_hb_mnist_summary (model_arch_table VARCHAR);\n",
"INSERT INTO mst_diag_table_hb_mnist_summary VALUES ('model_arch_table_mnist');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generalize table names"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"results_table = 'results_mnist'\n",
"\n",
"output_table = 'mnist_multi_model'\n",
"output_table_info = '_'.join([output_table, 'info'])\n",
"output_table_summary = '_'.join([output_table, 'summary'])\n",
"\n",
"mst_table = 'mst_table_hb_mnist'\n",
"mst_table_summary = '_'.join([mst_table, 'summary'])\n",
"\n",
"mst_diag_table = 'mst_diag_table_hb_mnist'\n",
"mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
"\n",
"model_arch_table = 'model_arch_library_mnist'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hyperband diagonal logic"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from random import random\n",
"from math import log, ceil\n",
"from time import time, ctime\n",
"\n",
"class Hyperband_diagonal:\n",
"    \"\"\"Hyperband scheduler that trains one diagonal across brackets at a time.\n",
"\n",
"    The full schedule (n_i configurations and r_i iterations for each rung i\n",
"    of each bracket s) is computed up front. run() then walks the diagonals:\n",
"    diagonal i holds, for every bracket already started, the rung that trains\n",
"    for the same number of iterations, so one multi-model fit covers them all.\n",
"\n",
"    Depends on notebook-level names defined in other cells: the table-name\n",
"    variables mst_table, mst_diag_table and output_table_info, and the\n",
"    database handles cur and conn.\n",
"    \"\"\"\n",
"\n",
"    def __init__( self, get_params_function, try_params_function ):\n",
"        \"\"\"Compute the schedule and populate the MST table.\n",
"\n",
"        get_params_function(n, s) is called once per bracket to insert n\n",
"        configurations tagged with bracket s.\n",
"        try_params_function(i, r, first_pass) is called once per diagonal i\n",
"        to train the diagonal table for r iterations.\n",
"        \"\"\"\n",
"        self.get_params = get_params_function\n",
"        self.try_params = try_params_function\n",
"\n",
"        self.max_iter = 9  # maximum iterations per configuration\n",
"        self.eta = 3  # defines downsampling rate (default = 3)\n",
"\n",
"        self.logeta = lambda x: log( x ) / log( self.eta )\n",
"        self.s_max = int( self.logeta( self.max_iter ))\n",
"        self.B = ( self.s_max + 1 ) * self.max_iter\n",
"        self.setup_full_schedule()\n",
"        self.create_mst_superset()\n",
"\n",
"        self.best_loss = np.inf\n",
"        self.best_accuracy = 0.0\n",
"\n",
"    def _initial_n_r(self, s):\n",
"        \"\"\"Return (n, r) for rung 0 of bracket s: configurations and iterations.\"\"\"\n",
"        n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
"        r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
"        return n, r\n",
"\n",
"    # create full Hyperband schedule for all brackets ahead of time\n",
"    def setup_full_schedule(self):\n",
"        \"\"\"Fill self.n_vals[s][i] and self.r_vals[s][i] and print the schedule.\"\"\"\n",
"        self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
"        self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
"\n",
"        print (\" \")\n",
"        print (\"Hyperband brackets\")\n",
"\n",
"        # loop through each bracket in reverse order\n",
"        for s in reversed(range(self.s_max+1)):\n",
"\n",
"            print (\" \")\n",
"            print (\"s=\" + str(s))\n",
"            print (\"n_i r_i\")\n",
"            print (\"------------\")\n",
"\n",
"            n, _ = self._initial_n_r(s)\n",
"\n",
"            #### Begin Finite Horizon Successive Halving with (n,r)\n",
"            for i in range(s+1):\n",
"                # n_i configs for r_i iterations.\n",
"                # Integer floor division is exact; multiplying by a float\n",
"                # eta**(-i) and truncating can land just below the true\n",
"                # integer for some eta (e.g. 49 * 7**-2) and store 0.\n",
"                n_i = n // self.eta**i\n",
"                r_i = self.max_iter // self.eta**(s-i)\n",
"\n",
"                self.n_vals[s][i] = n_i\n",
"                self.r_vals[s][i] = r_i\n",
"\n",
"                print (str(n_i) + \" \" + str (r_i))\n",
"            #### End Finite Horizon Successive Halving with (n,r)\n",
"\n",
"    # generate model selection tuples for all brackets\n",
"    def create_mst_superset(self):\n",
"        \"\"\"Call get_params(n, s) for every bracket s, largest s first.\"\"\"\n",
"        # get hyper parameter configs for each bracket s\n",
"        for s in reversed(range(self.s_max+1)):\n",
"            n, r = self._initial_n_r(s)\n",
"\n",
"            print (\" \")\n",
"            print (\"Create superset of MSTs, i.e., i=0 for for each bracket s\")\n",
"            print (\" \")\n",
"            print (\"s=\" + str(s))\n",
"            print (\"n=\" + str(n))\n",
"            print (\"r=\" + str(r))\n",
"            print (\" \")\n",
"\n",
"            # n random configurations for each bracket s\n",
"            self.get_params(n, s)\n",
"\n",
"    # Hyperband diagonal logic\n",
"    def run( self, skip_last = 0, dry_run = False ):\n",
"        \"\"\"Train every diagonal, pruning each bracket after its rung finishes.\n",
"\n",
"        skip_last and dry_run are accepted but not used by this method.\n",
"        Returns None; the running best loss and accuracy are kept in\n",
"        self.best_loss and self.best_accuracy.\n",
"        \"\"\"\n",
"        print (\" \")\n",
"        print (\"Hyperband diagonal\")\n",
"        print (\"outer loop on diagonal:\")\n",
"\n",
"        # outer loop on diagonal\n",
"        for i in range(self.s_max+1):\n",
"            print (\" \")\n",
"            print (\"i=\" + str(i))\n",
"\n",
"            # zero out diagonal table\n",
"            %sql TRUNCATE TABLE $mst_diag_table\n",
"\n",
"            # loop on brackets s desc to create diagonal table\n",
"            print (\"loop on s desc to create diagonal table:\")\n",
"            for s in range(self.s_max, self.s_max-i-1, -1):\n",
"                # build up mst table for diagonal\n",
"                %sql INSERT INTO $mst_diag_table (SELECT * FROM $mst_table WHERE s=$s);\n",
"\n",
"            # only the first diagonal is a cold start\n",
"            first_pass = (i == 0)\n",
"\n",
"            # multi-model training\n",
"            print (\" \")\n",
"            print (\"try params for i = \" + str(i))\n",
"            self.try_params(i, self.r_vals[self.s_max][i], first_pass) # r_i is the same for all diagonal elements\n",
"\n",
"            # loop on brackets s desc to prune model selection table\n",
"            # don't need to prune if finished last diagonal\n",
"            if i < self.s_max:\n",
"                print (\"loop on s desc to prune mst table:\")\n",
"                for s in range(self.s_max, self.s_max-i-1, -1):\n",
"\n",
"                    # compute number of configs to keep\n",
"                    # remember i value is different for each bracket s on the diagonal\n",
"                    k = int(self.n_vals[s][s-self.s_max+i]) // self.eta\n",
"                    print (\"pruning s = {} with k = {}\".format(s, k))\n",
"\n",
"                    # Table names come from the notebook-level variables and are\n",
"                    # passed explicitly, so they cannot drift from the config cell.\n",
"                    query = \"\"\"\n",
"                    DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
"                    \"\"\".format(mst_table=mst_table, output_table_info=output_table_info, s=s, k=k)\n",
"                    # the cursor is used because the equivalent %sql DELETE was not working\n",
"                    cur.execute(query)\n",
"                    conn.commit()\n",
"\n",
"            # keep track of best loss so far (for display purposes only)\n",
"            # assumes the magic returns a list-like result set of rows -- TODO confirm\n",
"            best_rows = %sql SELECT validation_loss_final, validation_metrics_final FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT 1;\n",
"\n",
"            if best_rows:\n",
"                loss = best_rows[0][0]\n",
"                accuracy = best_rows[0][1]\n",
"                # compare scalars, not result objects, so the running best really updates\n",
"                if loss is not None and loss < self.best_loss:\n",
"                    self.best_loss = loss\n",
"                    self.best_accuracy = accuracy\n",
"\n",
"            print (\" \")\n",
"            print (\"best validation loss so far = \" + str(self.best_loss))\n",
"            print (\"best validation accuracy so far = \" + str(self.best_accuracy))\n",
"\n",
"        return"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate params and insert into MST table"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"def get_params(n, s):\n",
"    \"\"\"Insert n randomly sampled configurations for bracket s into the MST table.\n",
"\n",
"    n: number of configurations to generate.\n",
"    s: bracket index written to the s column of every inserted row.\n",
"\n",
"    Fixed choices: model_id 1, categorical_crossentropy loss, accuracy metric,\n",
"    epochs 1. Sampled choices: optimizer (Adam or SGD), batch size (32, 64 or\n",
"    128) and a learning rate picked from n candidates drawn log-uniformly\n",
"    between 0.001 and 0.1.\n",
"    \"\"\"\n",
"    from sklearn.model_selection import ParameterSampler\n",
"    import numpy as np\n",
"\n",
"    # learning rate (sample on log scale here not in ParameterSampler)\n",
"    lr_range = [0.001, 0.1]\n",
"    lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n)\n",
"\n",
"    # create random param list\n",
"    param_grid = {\n",
"        'model_id': [1],\n",
"        'loss': ['categorical_crossentropy'],\n",
"        'optimizer': ['Adam', 'SGD'],\n",
"        'lr': lr,\n",
"        'metrics': ['accuracy'],\n",
"        'batch_size': [32, 64, 128],\n",
"        'epochs': [1]\n",
"    }\n",
"    param_list = list(ParameterSampler(param_grid, n_iter=n))\n",
"\n",
"    for params in param_list:\n",
"\n",
"        # $$ ... $$ is the quoting used for the compile/fit parameter strings\n",
"        compile_params = \"$$loss='{}',optimizer='{}(lr={})',metrics=['{}']$$\".format(\n",
"            str(params[\"loss\"]), str(params[\"optimizer\"]), str(params[\"lr\"]), str(params[\"metrics\"]))\n",
"        fit_params = \"$$batch_size={},epochs={}$$\".format(\n",
"            str(params[\"batch_size\"]), str(params[\"epochs\"]))\n",
"        row_content = \"({}, {}, {}, {});\".format(\n",
"            str(s), str(params[\"model_id\"]), compile_params, fit_params)\n",
"\n",
"        %sql INSERT INTO $mst_table (s, model_id, compile_params, fit_params) VALUES $row_content\n",
"\n",
"    return"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run model hopper for candidates in MST table"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"def try_params(i, r, first_pass):\n",
"    \"\"\"Train the diagonal MST table for r iterations and log the results.\n",
"\n",
"    i: diagonal index, stored in the i column of every logged row.\n",
"    r: number of iterations passed to the multi-model fit.\n",
"    first_pass: when true the output tables are dropped and training starts\n",
"        cold; otherwise the fit is called in its warm-start form.\n",
"\n",
"    After fitting, the info table is copied to temp_results, extended with\n",
"    columns from the summary table plus s and i, and appended to the results\n",
"    table.\n",
"    \"\"\"\n",
"    # multi-model fit\n",
"    # NOTE(review): the fit calls keep literal table names because passing\n",
"    # notebook variables as madlib arguments was reported not to work; they\n",
"    # must be kept in step with output_table and mst_diag_table by hand.\n",
"    if first_pass:\n",
"        # cold start\n",
"        %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
"        %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, FALSE, 'test_mnist_packed');\n",
"\n",
"    else:\n",
"        # warm start to continue from previous run\n",
"        %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, FALSE, 'test_mnist_packed', NULL, True);\n",
"\n",
"    # save results via temp table\n",
"    # add everything from info table\n",
"    %sql DROP TABLE IF EXISTS temp_results;\n",
"    %sql CREATE TABLE temp_results AS (SELECT * FROM $output_table_info);\n",
"\n",
"    # add summary table info and i value (same for each row)\n",
"    %sql ALTER TABLE temp_results ADD COLUMN model_arch_table TEXT, ADD COLUMN num_iterations INTEGER, ADD COLUMN start_training_time TIMESTAMP, ADD COLUMN end_training_time TIMESTAMP, ADD COLUMN s INTEGER, ADD COLUMN i INTEGER;\n",
"    %sql UPDATE temp_results SET model_arch_table = (SELECT model_arch_table FROM $output_table_summary), num_iterations = (SELECT num_iterations FROM $output_table_summary), start_training_time = (SELECT start_training_time FROM $output_table_summary), end_training_time = (SELECT end_training_time FROM $output_table_summary), i = $i;\n",
"\n",
"    # get the s value for each run (not the same for each row since diagonal table crosses multiple brackets)\n",
"    # the diagonal table is addressed through its variable, as in the code that fills it\n",
"    %sql UPDATE temp_results SET s = m.s FROM $mst_diag_table AS m WHERE m.mst_key = temp_results.mst_key;\n",
"\n",
"    # copy temp table into results table\n",
"    %sql INSERT INTO $results_table (SELECT * FROM temp_results);\n",
"\n",
"    return"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call Hyperband diagonal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Hyperband brackets\n",
" \n",
"s=2\n",
"n_i r_i\n",
"------------\n",
"9 1.0\n",
"3.0 3.0\n",
"1.0 9.0\n",
" \n",
"s=1\n",
"n_i r_i\n",
"------------\n",
"3 3.0\n",
"1.0 9.0\n",
" \n",
"s=0\n",
"n_i r_i\n",
"------------\n",
"3 9\n",
" \n",
"Create superset of MSTs, i.e., i=0 for for each bracket s\n",
" \n",
"s=2\n",
"n=9\n",
"r=1.0\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"Create superset of MSTs, i.e., i=0 for for each bracket s\n",
" \n",
"s=1\n",
"n=3\n",
"r=3.0\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"Create superset of MSTs, i.e., i=0 for for each bracket s\n",
" \n",
"s=0\n",
"n=3\n",
"r=9\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"Hyperband diagonal\n",
"outer loop on diagonal:\n",
" \n",
"i=0\n",
"Done.\n",
"loop on s desc to create diagonal table:\n",
"9 rows affected.\n",
" \n",
"try params for i = 0\n",
"Done.\n"
]
}
],
"source": [
"hp = Hyperband_diagonal(get_params, try_params )\n",
"results = hp.run()\n",
"#hp.n_vals[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"plot\"></a>\n",
"# 6. Plot results"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"from collections import defaultdict\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"sns.set_palette(sns.color_palette(\"hls\", 20))\n",
"plt.rcParams.update({'font.size': 12})\n",
"pd.set_option('display.max_colwidth', -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training dataset"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8 rows affected.\n"
]
},
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" fig.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn it's self\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" event.shiftKey = false;\n",
" // Send a \"J\" for go to next cell\n",
" event.which = 74;\n",
" event.keyCode = 74;\n",
" manager.command_mode();\n",
" manager.handle_keydown(event);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"999.9999783255842\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
"df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 12;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"fig.legend(ncol=4)\n",
"fig.tight_layout()\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Training metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Training loss curve')\n",
"\n",
"for run_id in df_results['run_id']:\n",
" df_output_info = %sql SELECT training_metrics,training_loss FROM $results_table WHERE run_id = $run_id\n",
" df_output_info = df_output_info.DataFrame()\n",
" training_metrics = df_output_info['training_metrics'][0]\n",
" training_loss = df_output_info['training_loss'][0]\n",
" X = range(len(training_metrics))\n",
" \n",
" ax_metric.plot(X, training_metrics, label=run_id, marker='o')\n",
" ax_loss.plot(X, training_loss, label=run_id, marker='o')\n",
"\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Validation dataset"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8 rows affected.\n"
]
},
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" fig.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn it's self\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" event.shiftKey = false;\n",
" // Send a \"J\" for go to next cell\n",
" event.which = 74;\n",
" event.keyCode = 74;\n",
" manager.command_mode();\n",
" manager.handle_keydown(event);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"999.9999783255842\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
"df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 12;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"# Set up side-by-side axes: validation metric on the left, validation loss on the right.\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Validation metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Validation loss curve')\n",
"\n",
"# df_results was fetched with SELECT *, so the per-iteration curves are already in\n",
"# memory; plot straight from it instead of issuing one more query per run_id.\n",
"curves = zip(df_results['run_id'], df_results['validation_metrics'], df_results['validation_loss'])\n",
"for run_id, validation_metrics, validation_loss in curves:\n",
"    X = range(len(validation_metrics))\n",
"\n",
"    ax_metric.plot(X, validation_metrics, label=run_id, marker='o')\n",
"    ax_loss.plot(X, validation_loss, label=run_id, marker='o')\n",
"\n",
"# Lay out and build the legend only after the curves and labels exist: a legend\n",
"# created on an empty figure has no entries. Both axes carry the same run_id\n",
"# labels, so take the handles from one of them to avoid duplicate entries.\n",
"fig.tight_layout(rect=[0, 0.15, 1, 1])  # keep the bottom strip free for the legend\n",
"handles, labels = ax_metric.get_legend_handles_labels()\n",
"fig.legend(handles, labels, ncol=4, loc='lower center')\n",
"\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"print\"></a>\n",
"# 7. Print run schedules (display only)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty print the regular Hyperband run schedule (one bracket at a time)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"max_iter = 9\n",
"eta = 3\n",
"B = 3*max_iter = 27\n",
" \n",
"s=2\n",
"n_i r_i\n",
"------------\n",
"9 1.0\n",
"3.0 3.0\n",
"1.0 9.0\n",
" \n",
"s=1\n",
"n_i r_i\n",
"------------\n",
"3 3.0\n",
"1.0 9.0\n",
" \n",
"s=0\n",
"n_i r_i\n",
"------------\n",
"3 9\n",
" \n",
"sum of configurations at leaf nodes across all s = 5.0\n",
"(if have more workers than this, they may not be 100% busy)\n"
]
}
],
"source": [
"import numpy as np\n",
"from math import log, ceil\n",
"\n",
"#input\n",
"max_iter = 9 # maximum iterations/epochs per configuration\n",
"eta = 3 # defines downsampling rate (default=3)\n",
"\n",
"logeta = lambda x: log(x)/log(eta)\n",
"# s_max = floor(log_eta(max_iter)). The float quotient can land just below an\n",
"# integer (e.g. max_iter=243, eta=3 gives 4.999...), and truncating that alone\n",
"# would drop a whole bracket, so correct the estimate against exact powers of eta.\n",
"s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
"if eta**(s_max+1) <= max_iter:\n",
"    s_max += 1\n",
"elif eta**s_max > max_iter:\n",
"    s_max -= 1\n",
"B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
"\n",
"#echo output\n",
"print (\"max_iter = \" + str(max_iter))\n",
"print (\"eta = \" + str(eta))\n",
"print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
"\n",
"sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
"\n",
"#### Begin Finite Horizon Hyperband outer loop. Repeat indefinitely.\n",
"for s in reversed(range(s_max+1)):\n",
"\n",
"    print (\" \")\n",
"    print (\"s=\" + str(s))\n",
"    print (\"n_i r_i\")\n",
"    print (\"------------\")\n",
"\n",
"    # int() floors the per-bracket share B/max_iter/(s+1) before it is scaled by eta**s\n",
"    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
"    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
"    #### Begin Finite Horizon Successive Halving with (n,r)\n",
"    #T = [ get_random_hyperparameter_configuration() for i in range(n) ]\n",
"    for i in range(s+1):\n",
"        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
"        n_i = n*eta**(-i)\n",
"        r_i = r*eta**(i)\n",
"\n",
"        print (str(n_i) + \" \" + str (r_i))\n",
"\n",
"        # the last rung (i == s) is the leaf node for this s\n",
"        if i == s:\n",
"            sum_leaf_n_i += n_i\n",
"\n",
"        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
"        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
"    #### End Finite Horizon Successive Halving with (n,r)\n",
"\n",
"print (\" \")\n",
"print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
"print (\"(if have more workers than this, they may not be 100% busy)\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty print Hyperband diagonal run schedule"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"echo input:\n",
"max_iter = 9\n",
"eta = 3\n",
"s_max = 2\n",
"B = 3*max_iter = 27\n",
" \n",
"initial n, r values for each s:\n",
"s=2\n",
"n=9\n",
"r=1.0\n",
" \n",
"s=1\n",
"n=3\n",
"r=3.0\n",
" \n",
"s=0\n",
"n=3\n",
"r=9\n",
" \n",
"outer loop on diagonal:\n",
" \n",
"i=0\n",
"inner loop on s desc:\n",
"s=2\n",
"n_i=9\n",
"r_i=1.0\n",
" \n",
"i=1\n",
"inner loop on s desc:\n",
"s=2\n",
"n_i=3.0\n",
"r_i=3.0\n",
"s=1\n",
"n_i=3\n",
"r_i=3.0\n",
" \n",
"i=2\n",
"inner loop on s desc:\n",
"s=2\n",
"n_i=1.0\n",
"r_i=9.0\n",
"s=1\n",
"n_i=1.0\n",
"r_i=9.0\n",
"s=0\n",
"n_i=3\n",
"r_i=9\n"
]
}
],
"source": [
"import numpy as np\n",
"from math import log, ceil\n",
"\n",
"#input\n",
"max_iter = 9 # maximum iterations/epochs per configuration\n",
"eta = 3 # defines downsampling rate (default=3)\n",
"\n",
"logeta = lambda x: log(x)/log(eta)\n",
"# s_max = floor(log_eta(max_iter)). The float quotient can land just below an\n",
"# integer (e.g. max_iter=243, eta=3 gives 4.999...), and truncating that alone\n",
"# would drop a whole bracket, so correct the estimate against exact powers of eta.\n",
"s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
"if eta**(s_max+1) <= max_iter:\n",
"    s_max += 1\n",
"elif eta**s_max > max_iter:\n",
"    s_max -= 1\n",
"B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
"\n",
"#echo output\n",
"print (\"echo input:\")\n",
"print (\"max_iter = \" + str(max_iter))\n",
"print (\"eta = \" + str(eta))\n",
"print (\"s_max = \" + str(s_max))\n",
"print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
"\n",
"print (\" \")\n",
"print (\"initial n, r values for each s:\")\n",
"initial_n_vals = {}\n",
"initial_r_vals = {}\n",
"# get the initial (n, r) pair for each s, keyed by s for the diagonal pass below\n",
"for s in reversed(range(s_max+1)):\n",
"\n",
"    # int() floors the per-bracket share B/max_iter/(s+1) before it is scaled by eta**s\n",
"    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
"    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
"    initial_n_vals[s] = n\n",
"    initial_r_vals[s] = r\n",
"\n",
"    print (\"s=\" + str(s))\n",
"    print (\"n=\" + str(n))\n",
"    print (\"r=\" + str(r))\n",
"    print (\" \")\n",
"\n",
"print (\"outer loop on diagonal:\")\n",
"# outer loop on diagonal\n",
"for i in range(s_max+1):\n",
"    print (\" \")\n",
"    print (\"i=\" + str(i))\n",
"\n",
"    print (\"inner loop on s desc:\")\n",
"    # inner loop on s desc: diagonal step i visits s = s_max down to s_max-i,\n",
"    # and for each s the rung index within that bracket is i-(s_max-s)\n",
"    for s in range(s_max, s_max-i-1, -1):\n",
"        n_i = initial_n_vals[s]*eta**(-i+s_max-s)\n",
"        r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
"\n",
"        print (\"s=\" + str(s))\n",
"        print (\"n_i=\" + str(n_i))\n",
"        print (\"r_i=\" + str(r_i))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}