{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Selection for Multilayer Perceptron Using Keras and MADlib\n",
"\n",
"End-to-end (E2E) classification example that uses MADlib to train a Keras multilayer perceptron (MLP) across different hyperparameters and model architectures.\n",
"\n",
"Deep learning works best on very large datasets, but those are inconvenient for a quick introduction to the syntax. So in this notebook we use the well-known iris data set from https://archive.ics.uci.edu/ml/datasets/iris to help you get started. It is similar to the example in the user docs at http://madlib.apache.org/docs/latest/index.html\n",
"\n",
"For more realistic examples, please refer to the deep learning notebooks at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts\n",
"\n",
"## Table of contents\n",
"\n",
"<a href=\"#class\">Classification</a>\n",
"\n",
"* <a href=\"#create_input_data\">1. Create input data</a>\n",
"\n",
"* <a href=\"#pp\">2. Call preprocessor for deep learning</a>\n",
"\n",
"* <a href=\"#load\">3. Define and load model architecture</a>\n",
"\n",
"* <a href=\"#def_mst\">4. Define and load model selection tuples</a>\n",
"\n",
"* <a href=\"#train\">5. Train</a>\n",
"\n",
"* <a href=\"#eval\">6. Evaluate</a>\n",
"\n",
"* <a href=\"#pred\">7. Predict</a>\n",
"\n",
"<a href=\"#class2\">Classification with Other Parameters</a>\n",
"\n",
"* <a href=\"#val_dataset\">1. Validation dataset</a>\n",
"\n",
"* <a href=\"#pred_prob\">2. Predict probabilities</a>\n",
"\n",
"* <a href=\"#warm_start\">3. Warm start</a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: gpadmin@madlib'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
"#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
"\n",
"# Greenplum Database 5.x on GCP - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.17-dev, git revision: rel/v1.16-54-gec5614f, cmake configuration time: Wed Dec 18 17:08:05 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-54-gec5614f, cmake configuration time: Wed Dec 18 17:08:05 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"class\"></a>\n",
"# Classification\n",
"\n",
"<a id=\"create_input_data\"></a>\n",
"# 1. Create input data\n",
"\n",
"Load the iris data set."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"150 rows affected.\n",
"150 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>attributes</th>\n",
" <th>class_text</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>[Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>[Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>[Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>[Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>[Decimal('5.7'), Decimal('3.8'), Decimal('1.7'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.5'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>[Decimal('5.4'), Decimal('3.4'), Decimal('1.7'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>[Decimal('5.1'), Decimal('3.7'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>[Decimal('4.6'), Decimal('3.6'), Decimal('1.0'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>[Decimal('5.1'), Decimal('3.3'), Decimal('1.7'), Decimal('0.5')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.9'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>[Decimal('5.0'), Decimal('3.0'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.6'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>[Decimal('5.2'), Decimal('3.5'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>[Decimal('5.2'), Decimal('3.4'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>[Decimal('4.8'), Decimal('3.1'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32</td>\n",
" <td>[Decimal('5.4'), Decimal('3.4'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33</td>\n",
" <td>[Decimal('5.2'), Decimal('4.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>[Decimal('5.5'), Decimal('4.2'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36</td>\n",
" <td>[Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37</td>\n",
" <td>[Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>39</td>\n",
" <td>[Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>[Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>41</td>\n",
" <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>42</td>\n",
" <td>[Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>[Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>44</td>\n",
" <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>46</td>\n",
" <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>47</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>48</td>\n",
" <td>[Decimal('4.6'), Decimal('3.2'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>[Decimal('5.3'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>[Decimal('5.0'), Decimal('3.3'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>[Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>52</td>\n",
" <td>[Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>53</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>54</td>\n",
" <td>[Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>[Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>57</td>\n",
" <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>58</td>\n",
" <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>[Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>62</td>\n",
" <td>[Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>63</td>\n",
" <td>[Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>[Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>65</td>\n",
" <td>[Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>66</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>68</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>69</td>\n",
" <td>[Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>[Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>71</td>\n",
" <td>[Decimal('5.9'), Decimal('3.2'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>[Decimal('6.1'), Decimal('2.8'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>73</td>\n",
" <td>[Decimal('6.3'), Decimal('2.5'), Decimal('4.9'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>74</td>\n",
" <td>[Decimal('6.1'), Decimal('2.8'), Decimal('4.7'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>[Decimal('6.4'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>[Decimal('6.6'), Decimal('3.0'), Decimal('4.4'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77</td>\n",
" <td>[Decimal('6.8'), Decimal('2.8'), Decimal('4.8'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>78</td>\n",
" <td>[Decimal('6.7'), Decimal('3.0'), Decimal('5.0'), Decimal('1.7')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>79</td>\n",
" <td>[Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>[Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>81</td>\n",
" <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>83</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>[Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>85</td>\n",
" <td>[Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>86</td>\n",
" <td>[Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>87</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>88</td>\n",
" <td>[Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>90</td>\n",
" <td>[Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>91</td>\n",
" <td>[Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>93</td>\n",
" <td>[Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>[Decimal('5.0'), Decimal('2.3'), Decimal('3.3'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>95</td>\n",
" <td>[Decimal('5.6'), Decimal('2.7'), Decimal('4.2'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>96</td>\n",
" <td>[Decimal('5.7'), Decimal('3.0'), Decimal('4.2'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>97</td>\n",
" <td>[Decimal('5.7'), Decimal('2.9'), Decimal('4.2'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>[Decimal('6.2'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>[Decimal('5.1'), Decimal('2.5'), Decimal('3.0'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.1'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>[Decimal('6.3'), Decimal('3.3'), Decimal('6.0'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>[Decimal('7.1'), Decimal('3.0'), Decimal('5.9'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>104</td>\n",
" <td>[Decimal('6.3'), Decimal('2.9'), Decimal('5.6'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>105</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.8'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>106</td>\n",
" <td>[Decimal('7.6'), Decimal('3.0'), Decimal('6.6'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>[Decimal('4.9'), Decimal('2.5'), Decimal('4.5'), Decimal('1.7')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>108</td>\n",
" <td>[Decimal('7.3'), Decimal('2.9'), Decimal('6.3'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>109</td>\n",
" <td>[Decimal('6.7'), Decimal('2.5'), Decimal('5.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>110</td>\n",
" <td>[Decimal('7.2'), Decimal('3.6'), Decimal('6.1'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>111</td>\n",
" <td>[Decimal('6.5'), Decimal('3.2'), Decimal('5.1'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>[Decimal('6.4'), Decimal('2.7'), Decimal('5.3'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>[Decimal('6.8'), Decimal('3.0'), Decimal('5.5'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>[Decimal('5.7'), Decimal('2.5'), Decimal('5.0'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>[Decimal('5.8'), Decimal('2.8'), Decimal('5.1'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>[Decimal('6.4'), Decimal('3.2'), Decimal('5.3'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.5'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>118</td>\n",
" <td>[Decimal('7.7'), Decimal('3.8'), Decimal('6.7'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>[Decimal('7.7'), Decimal('2.6'), Decimal('6.9'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>[Decimal('6.0'), Decimal('2.2'), Decimal('5.0'), Decimal('1.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>[Decimal('6.9'), Decimal('3.2'), Decimal('5.7'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>122</td>\n",
" <td>[Decimal('5.6'), Decimal('2.8'), Decimal('4.9'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>[Decimal('7.7'), Decimal('2.8'), Decimal('6.7'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>124</td>\n",
" <td>[Decimal('6.3'), Decimal('2.7'), Decimal('4.9'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>[Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>126</td>\n",
" <td>[Decimal('7.2'), Decimal('3.2'), Decimal('6.0'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>[Decimal('6.2'), Decimal('2.8'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>128</td>\n",
" <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.9'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>129</td>\n",
" <td>[Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>130</td>\n",
" <td>[Decimal('7.2'), Decimal('3.0'), Decimal('5.8'), Decimal('1.6')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>131</td>\n",
" <td>[Decimal('7.4'), Decimal('2.8'), Decimal('6.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>132</td>\n",
" <td>[Decimal('7.9'), Decimal('3.8'), Decimal('6.4'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>133</td>\n",
" <td>[Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>134</td>\n",
" <td>[Decimal('6.3'), Decimal('2.8'), Decimal('5.1'), Decimal('1.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>135</td>\n",
" <td>[Decimal('6.1'), Decimal('2.6'), Decimal('5.6'), Decimal('1.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>[Decimal('7.7'), Decimal('3.0'), Decimal('6.1'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>137</td>\n",
" <td>[Decimal('6.3'), Decimal('3.4'), Decimal('5.6'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>138</td>\n",
" <td>[Decimal('6.4'), Decimal('3.1'), Decimal('5.5'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>139</td>\n",
" <td>[Decimal('6.0'), Decimal('3.0'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>140</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('5.4'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>141</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('5.6'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>142</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('5.1'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>143</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>[Decimal('6.8'), Decimal('3.2'), Decimal('5.9'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>[Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>[Decimal('6.7'), Decimal('3.0'), Decimal('5.2'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>[Decimal('6.3'), Decimal('2.5'), Decimal('5.0'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.2'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>[Decimal('6.2'), Decimal('3.4'), Decimal('5.4'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>150</td>\n",
" <td>[Decimal('5.9'), Decimal('3.0'), Decimal('5.1'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (2, [Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (3, [Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (4, [Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (5, [Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (6, [Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')], u'Iris-setosa'),\n",
" (7, [Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (8, [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (9, [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (10, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (11, [Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (12, [Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (13, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')], u'Iris-setosa'),\n",
" (14, [Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')], u'Iris-setosa'),\n",
" (15, [Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')], u'Iris-setosa'),\n",
" (16, [Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (17, [Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')], u'Iris-setosa'),\n",
" (18, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (19, [Decimal('5.7'), Decimal('3.8'), Decimal('1.7'), Decimal('0.3')], u'Iris-setosa'),\n",
" (20, [Decimal('5.1'), Decimal('3.8'), Decimal('1.5'), Decimal('0.3')], u'Iris-setosa'),\n",
" (21, [Decimal('5.4'), Decimal('3.4'), Decimal('1.7'), Decimal('0.2')], u'Iris-setosa'),\n",
" (22, [Decimal('5.1'), Decimal('3.7'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (23, [Decimal('4.6'), Decimal('3.6'), Decimal('1.0'), Decimal('0.2')], u'Iris-setosa'),\n",
" (24, [Decimal('5.1'), Decimal('3.3'), Decimal('1.7'), Decimal('0.5')], u'Iris-setosa'),\n",
" (25, [Decimal('4.8'), Decimal('3.4'), Decimal('1.9'), Decimal('0.2')], u'Iris-setosa'),\n",
" (26, [Decimal('5.0'), Decimal('3.0'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (27, [Decimal('5.0'), Decimal('3.4'), Decimal('1.6'), Decimal('0.4')], u'Iris-setosa'),\n",
" (28, [Decimal('5.2'), Decimal('3.5'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (29, [Decimal('5.2'), Decimal('3.4'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (30, [Decimal('4.7'), Decimal('3.2'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (31, [Decimal('4.8'), Decimal('3.1'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (32, [Decimal('5.4'), Decimal('3.4'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (33, [Decimal('5.2'), Decimal('4.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (34, [Decimal('5.5'), Decimal('4.2'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (35, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (36, [Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')], u'Iris-setosa'),\n",
" (37, [Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (38, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (39, [Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (40, [Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (41, [Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')], u'Iris-setosa'),\n",
" (42, [Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')], u'Iris-setosa'),\n",
" (43, [Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (44, [Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')], u'Iris-setosa'),\n",
" (45, [Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')], u'Iris-setosa'),\n",
" (46, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (47, [Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (48, [Decimal('4.6'), Decimal('3.2'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (49, [Decimal('5.3'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (50, [Decimal('5.0'), Decimal('3.3'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (51, [Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (52, [Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (53, [Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (54, [Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (55, [Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (56, [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (57, [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (58, [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (59, [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (60, [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (61, [Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (62, [Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (63, [Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (64, [Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (65, [Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (66, [Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (67, [Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (68, [Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (69, [Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (70, [Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (71, [Decimal('5.9'), Decimal('3.2'), Decimal('4.8'), Decimal('1.8')], u'Iris-versicolor'),\n",
" (72, [Decimal('6.1'), Decimal('2.8'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (73, [Decimal('6.3'), Decimal('2.5'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (74, [Decimal('6.1'), Decimal('2.8'), Decimal('4.7'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (75, [Decimal('6.4'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (76, [Decimal('6.6'), Decimal('3.0'), Decimal('4.4'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (77, [Decimal('6.8'), Decimal('2.8'), Decimal('4.8'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (78, [Decimal('6.7'), Decimal('3.0'), Decimal('5.0'), Decimal('1.7')], u'Iris-versicolor'),\n",
" (79, [Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (80, [Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (81, [Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (82, [Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (83, [Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (84, [Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (85, [Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (86, [Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (87, [Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (88, [Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (89, [Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (90, [Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (91, [Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (92, [Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (93, [Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (94, [Decimal('5.0'), Decimal('2.3'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (95, [Decimal('5.6'), Decimal('2.7'), Decimal('4.2'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (96, [Decimal('5.7'), Decimal('3.0'), Decimal('4.2'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (97, [Decimal('5.7'), Decimal('2.9'), Decimal('4.2'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (98, [Decimal('6.2'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (99, [Decimal('5.1'), Decimal('2.5'), Decimal('3.0'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (100, [Decimal('5.7'), Decimal('2.8'), Decimal('4.1'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (101, [Decimal('6.3'), Decimal('3.3'), Decimal('6.0'), Decimal('2.5')], u'Iris-virginica'),\n",
" (102, [Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (103, [Decimal('7.1'), Decimal('3.0'), Decimal('5.9'), Decimal('2.1')], u'Iris-virginica'),\n",
" (104, [Decimal('6.3'), Decimal('2.9'), Decimal('5.6'), Decimal('1.8')], u'Iris-virginica'),\n",
" (105, [Decimal('6.5'), Decimal('3.0'), Decimal('5.8'), Decimal('2.2')], u'Iris-virginica'),\n",
" (106, [Decimal('7.6'), Decimal('3.0'), Decimal('6.6'), Decimal('2.1')], u'Iris-virginica'),\n",
" (107, [Decimal('4.9'), Decimal('2.5'), Decimal('4.5'), Decimal('1.7')], u'Iris-virginica'),\n",
" (108, [Decimal('7.3'), Decimal('2.9'), Decimal('6.3'), Decimal('1.8')], u'Iris-virginica'),\n",
" (109, [Decimal('6.7'), Decimal('2.5'), Decimal('5.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (110, [Decimal('7.2'), Decimal('3.6'), Decimal('6.1'), Decimal('2.5')], u'Iris-virginica'),\n",
" (111, [Decimal('6.5'), Decimal('3.2'), Decimal('5.1'), Decimal('2.0')], u'Iris-virginica'),\n",
" (112, [Decimal('6.4'), Decimal('2.7'), Decimal('5.3'), Decimal('1.9')], u'Iris-virginica'),\n",
" (113, [Decimal('6.8'), Decimal('3.0'), Decimal('5.5'), Decimal('2.1')], u'Iris-virginica'),\n",
" (114, [Decimal('5.7'), Decimal('2.5'), Decimal('5.0'), Decimal('2.0')], u'Iris-virginica'),\n",
" (115, [Decimal('5.8'), Decimal('2.8'), Decimal('5.1'), Decimal('2.4')], u'Iris-virginica'),\n",
" (116, [Decimal('6.4'), Decimal('3.2'), Decimal('5.3'), Decimal('2.3')], u'Iris-virginica'),\n",
" (117, [Decimal('6.5'), Decimal('3.0'), Decimal('5.5'), Decimal('1.8')], u'Iris-virginica'),\n",
" (118, [Decimal('7.7'), Decimal('3.8'), Decimal('6.7'), Decimal('2.2')], u'Iris-virginica'),\n",
" (119, [Decimal('7.7'), Decimal('2.6'), Decimal('6.9'), Decimal('2.3')], u'Iris-virginica'),\n",
" (120, [Decimal('6.0'), Decimal('2.2'), Decimal('5.0'), Decimal('1.5')], u'Iris-virginica'),\n",
" (121, [Decimal('6.9'), Decimal('3.2'), Decimal('5.7'), Decimal('2.3')], u'Iris-virginica'),\n",
" (122, [Decimal('5.6'), Decimal('2.8'), Decimal('4.9'), Decimal('2.0')], u'Iris-virginica'),\n",
" (123, [Decimal('7.7'), Decimal('2.8'), Decimal('6.7'), Decimal('2.0')], u'Iris-virginica'),\n",
" (124, [Decimal('6.3'), Decimal('2.7'), Decimal('4.9'), Decimal('1.8')], u'Iris-virginica'),\n",
" (125, [Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.1')], u'Iris-virginica'),\n",
" (126, [Decimal('7.2'), Decimal('3.2'), Decimal('6.0'), Decimal('1.8')], u'Iris-virginica'),\n",
" (127, [Decimal('6.2'), Decimal('2.8'), Decimal('4.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (128, [Decimal('6.1'), Decimal('3.0'), Decimal('4.9'), Decimal('1.8')], u'Iris-virginica'),\n",
" (129, [Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.1')], u'Iris-virginica'),\n",
" (130, [Decimal('7.2'), Decimal('3.0'), Decimal('5.8'), Decimal('1.6')], u'Iris-virginica'),\n",
" (131, [Decimal('7.4'), Decimal('2.8'), Decimal('6.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (132, [Decimal('7.9'), Decimal('3.8'), Decimal('6.4'), Decimal('2.0')], u'Iris-virginica'),\n",
" (133, [Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.2')], u'Iris-virginica'),\n",
" (134, [Decimal('6.3'), Decimal('2.8'), Decimal('5.1'), Decimal('1.5')], u'Iris-virginica'),\n",
" (135, [Decimal('6.1'), Decimal('2.6'), Decimal('5.6'), Decimal('1.4')], u'Iris-virginica'),\n",
" (136, [Decimal('7.7'), Decimal('3.0'), Decimal('6.1'), Decimal('2.3')], u'Iris-virginica'),\n",
" (137, [Decimal('6.3'), Decimal('3.4'), Decimal('5.6'), Decimal('2.4')], u'Iris-virginica'),\n",
" (138, [Decimal('6.4'), Decimal('3.1'), Decimal('5.5'), Decimal('1.8')], u'Iris-virginica'),\n",
" (139, [Decimal('6.0'), Decimal('3.0'), Decimal('4.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (140, [Decimal('6.9'), Decimal('3.1'), Decimal('5.4'), Decimal('2.1')], u'Iris-virginica'),\n",
" (141, [Decimal('6.7'), Decimal('3.1'), Decimal('5.6'), Decimal('2.4')], u'Iris-virginica'),\n",
" (142, [Decimal('6.9'), Decimal('3.1'), Decimal('5.1'), Decimal('2.3')], u'Iris-virginica'),\n",
" (143, [Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (144, [Decimal('6.8'), Decimal('3.2'), Decimal('5.9'), Decimal('2.3')], u'Iris-virginica'),\n",
" (145, [Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.5')], u'Iris-virginica'),\n",
" (146, [Decimal('6.7'), Decimal('3.0'), Decimal('5.2'), Decimal('2.3')], u'Iris-virginica'),\n",
" (147, [Decimal('6.3'), Decimal('2.5'), Decimal('5.0'), Decimal('1.9')], u'Iris-virginica'),\n",
" (148, [Decimal('6.5'), Decimal('3.0'), Decimal('5.2'), Decimal('2.0')], u'Iris-virginica'),\n",
" (149, [Decimal('6.2'), Decimal('3.4'), Decimal('5.4'), Decimal('2.3')], u'Iris-virginica'),\n",
" (150, [Decimal('5.9'), Decimal('3.0'), Decimal('5.1'), Decimal('1.8')], u'Iris-virginica')]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS iris_data;\n",
"\n",
"CREATE TABLE iris_data(\n",
" id serial,\n",
" attributes numeric[],\n",
" class_text varchar\n",
");\n",
"\n",
"INSERT INTO iris_data(id, attributes, class_text) VALUES\n",
"(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),\n",
"(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),\n",
"(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),\n",
"(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),\n",
"(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),\n",
"(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),\n",
"(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),\n",
"(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),\n",
"(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),\n",
"(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),\n",
"(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),\n",
"(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),\n",
"(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),\n",
"(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),\n",
"(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),\n",
"(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),\n",
"(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),\n",
"(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),\n",
"(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),\n",
"(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),\n",
"(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),\n",
"(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),\n",
"(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),\n",
"(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),\n",
"(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),\n",
"(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),\n",
"(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),\n",
"(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),\n",
"(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),\n",
"(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),\n",
"(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),\n",
"(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),\n",
"(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),\n",
"(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),\n",
"(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),\n",
"(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),\n",
"(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),\n",
"(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),\n",
"(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),\n",
"(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),\n",
"(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),\n",
"(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),\n",
"(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),\n",
"(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),\n",
"(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),\n",
"(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),\n",
"(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),\n",
"(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),\n",
"(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),\n",
"(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),\n",
"(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),\n",
"(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),\n",
"(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),\n",
"(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),\n",
"(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),\n",
"(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),\n",
"(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),\n",
"(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),\n",
"(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),\n",
"(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),\n",
"(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),\n",
"(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),\n",
"(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),\n",
"(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),\n",
"(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),\n",
"(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),\n",
"(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),\n",
"(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),\n",
"(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),\n",
"(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),\n",
"(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),\n",
"(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),\n",
"(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),\n",
"(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),\n",
"(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),\n",
"(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),\n",
"(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),\n",
"(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),\n",
"(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),\n",
"(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),\n",
"(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),\n",
"(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),\n",
"(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),\n",
"(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),\n",
"(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),\n",
"(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),\n",
"(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),\n",
"(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),\n",
"(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),\n",
"(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),\n",
"(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),\n",
"(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),\n",
"(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),\n",
"(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),\n",
"(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),\n",
"(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),\n",
"(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),\n",
"(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),\n",
"(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),\n",
"(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),\n",
"(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),\n",
"(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),\n",
"(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),\n",
"(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),\n",
"(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),\n",
"(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),\n",
"(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),\n",
"(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),\n",
"(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),\n",
"(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),\n",
"(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),\n",
"(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),\n",
"(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),\n",
"(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),\n",
"(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),\n",
"(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),\n",
"(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),\n",
"(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),\n",
"(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),\n",
"(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),\n",
"(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),\n",
"(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),\n",
"(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),\n",
"(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),\n",
"(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),\n",
"(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),\n",
"(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),\n",
"(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),\n",
"(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),\n",
"(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),\n",
"(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),\n",
"(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),\n",
"(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),\n",
"(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),\n",
"(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),\n",
"(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),\n",
"(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),\n",
"(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),\n",
"(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),\n",
"(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),\n",
"(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),\n",
"(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),\n",
"(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),\n",
"(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),\n",
"(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),\n",
"(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),\n",
"(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');\n",
"\n",
"SELECT * FROM iris_data ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split the source data into a training set (80%) and a test/validation set (20%):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(120L,)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_train, iris_test;\n",
"\n",
"-- Set seed so results are reproducible\n",
"SELECT setseed(0);\n",
"\n",
"SELECT madlib.train_test_split('iris_data', -- Source table\n",
" 'iris', -- Output table root name\n",
" 0.8, -- Train proportion\n",
" NULL, -- Test proportion (0.2)\n",
" NULL, -- Strata definition\n",
" NULL, -- Output all columns\n",
" NULL, -- Sample without replacement\n",
" TRUE -- Separate output tables\n",
" );\n",
"\n",
"SELECT COUNT(*) FROM iris_train;"
]
},
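{
"cell_type": "markdown",
"metadata": {},
"source": [
"`madlib.train_test_split` performs the split above inside the database. As a rough mental model only, here is a minimal pure-Python sketch of a seeded 80/20 split of the 150 row ids. It is not MADlib's implementation. `split_ids` is a hypothetical helper, and Python's `random` module will not pick the same rows as PostgreSQL's `setseed(0)`. Only the 120/30 row counts are expected to match:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def split_ids(ids, train_prop, seed):\n",
"    # Hypothetical helper: shuffle the ids with a fixed seed, then cut at train_prop.\n",
"    rng = random.Random(seed)\n",
"    shuffled = list(ids)\n",
"    rng.shuffle(shuffled)\n",
"    n_train = int(round(train_prop * len(shuffled)))\n",
"    return sorted(shuffled[:n_train]), sorted(shuffled[n_train:])\n",
"\n",
"train_ids, test_ids = split_ids(range(1, 151), 0.8, seed=0)\n",
"print('train: %d, test: %d' % (len(train_ids), len(test_ids)))  # train: 120, test: 30\n",
"```"
]
},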
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pp\"></a>\n",
"# 2. Call preprocessor for deep learning\n",
"Training dataset (uses training preprocessor):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>independent_var_shape</th>\n",
" <th>dependent_var_shape</th>\n",
" <th>buffer_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[60, 4]</td>\n",
" <td>[60, 3]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[60, 4]</td>\n",
" <td>[60, 3]</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([60, 4], [60, 3], 0), ([60, 4], [60, 3], 1)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;\n",
"\n",
"SELECT madlib.training_preprocessor_dl('iris_train', -- Source table\n",
" 'iris_train_packed', -- Output table\n",
" 'class_text', -- Dependent variable\n",
" 'attributes' -- Independent variable\n",
" ); \n",
"\n",
"SELECT independent_var_shape, dependent_var_shape, buffer_id FROM iris_train_packed ORDER BY buffer_id;"
]
},
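{
"cell_type": "markdown",
"metadata": {},
"source": [
"`training_preprocessor_dl` one-hot encodes the dependent variable and packs the 120 training rows into buffers. This is why each buffer has shape `[60, 4]` for the features and `[60, 3]` for the labels. The following is a minimal pure-Python sketch of that idea, not MADlib's implementation, and it rests on several assumptions. `pack_rows` is a hypothetical helper. The buffer size is passed in explicitly, whereas MADlib chose 60 here on its own. Classes are assumed to be ordered by sorting, which matches the `class_values` in the summary table. The input rows are placeholders rather than the real `iris_train` contents. Finally, MADlib also distributes buffers across segments, which the sketch ignores:\n",
"\n",
"```python\n",
"def pack_rows(features, labels, buffer_size):\n",
"    # Hypothetical helper. Sorted class order defines the one-hot columns.\n",
"    class_values = sorted(set(labels))\n",
"    index = dict((c, i) for i, c in enumerate(class_values))\n",
"    one_hot = []\n",
"    for label in labels:\n",
"        row = [0] * len(class_values)\n",
"        row[index[label]] = 1\n",
"        one_hot.append(row)\n",
"    # Cut the rows into consecutive buffers of at most buffer_size rows.\n",
"    buffers = []\n",
"    for start in range(0, len(features), buffer_size):\n",
"        buffers.append((features[start:start + buffer_size],\n",
"                        one_hot[start:start + buffer_size]))\n",
"    return class_values, buffers\n",
"\n",
"# Placeholder data (not the real iris_train rows): 120 rows, 40 per class.\n",
"features = [[0.0, 0.0, 0.0, 0.0] for _ in range(120)]\n",
"labels = ['Iris-setosa'] * 40 + ['Iris-versicolor'] * 40 + ['Iris-virginica'] * 40\n",
"class_values, buffers = pack_rows(features, labels, buffer_size=60)\n",
"for x, y in buffers:\n",
"    print('%s %s' % ([len(x), len(x[0])], [len(y), len(y[0])]))  # [60, 4] [60, 3]\n",
"```"
]
},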
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train</td>\n",
" <td>iris_train_packed</td>\n",
" <td>class_text</td>\n",
" <td>attributes</td>\n",
" <td>character varying</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>60</td>\n",
" <td>1.0</td>\n",
" <td>3</td>\n",
" <td>all_segments</td>\n",
" <td>all_segments</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train', u'iris_train_packed', u'class_text', u'attributes', u'character varying', [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], 60, 1.0, 3, 'all_segments', 'all_segments')]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_train_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Validation dataset (uses validation preprocessor):"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>independent_var_shape</th>\n",
" <th>dependent_var_shape</th>\n",
" <th>buffer_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[15, 4]</td>\n",
" <td>[15, 3]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[15, 4]</td>\n",
" <td>[15, 3]</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([15, 4], [15, 3], 0), ([15, 4], [15, 3], 1)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;\n",
"\n",
"SELECT madlib.validation_preprocessor_dl('iris_test', -- Source table\n",
" 'iris_test_packed', -- Output table\n",
" 'class_text', -- Dependent variable\n",
" 'attributes', -- Independent variable\n",
" 'iris_train_packed' -- From training preprocessor step\n",
" ); \n",
"\n",
"SELECT independent_var_shape, dependent_var_shape, buffer_id FROM iris_test_packed ORDER BY buffer_id;"
]
},
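{
"cell_type": "markdown",
"metadata": {},
"source": [
"`validation_preprocessor_dl` takes `iris_train_packed` as its last argument, so the test set is encoded with the same `class_values` as the training set. This matters because an independently encoded test subset that happened to lack one class would shift the one-hot columns. The sketch below illustrates the point. The function name is hypothetical, and raising on an unseen label is this sketch's choice, not documented MADlib behavior:\n",
"\n",
"```python\n",
"def encode_with_training_classes(labels, class_values):\n",
"    # Hypothetical helper. The column order comes from the training data,\n",
"    # so column j means the same class in both packed tables.\n",
"    index = dict((c, i) for i, c in enumerate(class_values))\n",
"    encoded = []\n",
"    for label in labels:\n",
"        if label not in index:\n",
"            raise ValueError('label not seen in training data: %s' % label)\n",
"        row = [0] * len(class_values)\n",
"        row[index[label]] = 1\n",
"        encoded.append(row)\n",
"    return encoded\n",
"\n",
"train_classes = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']\n",
"# Even a batch containing only two of the classes still gets all 3 columns.\n",
"print(encode_with_training_classes(['Iris-virginica', 'Iris-setosa'], train_classes))\n",
"# [[0, 0, 1], [1, 0, 0]]\n",
"```"
]
},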
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_test</td>\n",
" <td>iris_test_packed</td>\n",
" <td>class_text</td>\n",
" <td>attributes</td>\n",
" <td>character varying</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>15</td>\n",
" <td>1.0</td>\n",
" <td>3</td>\n",
" <td>all_segments</td>\n",
" <td>all_segments</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_test', u'iris_test_packed', u'class_text', u'attributes', u'character varying', [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], 15, 1.0, 3, 'all_segments', 'all_segments')]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_test_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load\"></a>\n",
"# 3. Define and load model architecture\n",
"Import Keras libraries"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n"
]
}
],
"source": [
"import keras\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a model architecture with two hidden layers of 10 units each (the first `Dense` layer also declares the 4-feature input shape):"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_1 (Dense) (None, 10) 50 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 3) 33 \n",
"=================================================================\n",
"Total params: 193\n",
"Trainable params: 193\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model1 = Sequential()\n",
"model1.add(Dense(10, activation='relu', input_shape=(4,)))\n",
"model1.add(Dense(10, activation='relu'))\n",
"model1.add(Dense(3, activation='softmax'))\n",
" \n",
"model1.summary()"
]
},
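{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Param #` column follows from the structure of a `Dense` layer. With $n_{\\mathrm{in}}$ inputs and $n_{\\mathrm{out}}$ units, the layer has an $n_{\\mathrm{in}} \\times n_{\\mathrm{out}}$ weight matrix plus $n_{\\mathrm{out}}$ biases. For `model1`, with 4 input features:\n",
"\n",
"$$P_{\\mathrm{Dense}} = n_{\\mathrm{in}} \\cdot n_{\\mathrm{out}} + n_{\\mathrm{out}}$$\n",
"\n",
"$$4 \\cdot 10 + 10 = 50, \\qquad 10 \\cdot 10 + 10 = 110, \\qquad 10 \\cdot 3 + 3 = 33, \\qquad 50 + 110 + 33 = 193$$"
]
},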
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model1.to_json()"
]
},
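{
"cell_type": "markdown",
"metadata": {},
"source": [
"The JSON string above describes only the architecture: layer types, sizes, activations and initializers. It carries no trained weights. The sketch below reads the layer structure back out with Python's standard `json` module. It uses an abbreviated copy of the string above with most keys dropped. It also assumes the Keras 2.1.6 layout shown, where `config` is a list of layers; later Keras releases use a dict with a `layers` key instead:\n",
"\n",
"```python\n",
"import json\n",
"\n",
"# Abbreviated copy of model1.to_json(): only the keys used below are kept.\n",
"arch_json = ('{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": ['\n",
"             '{\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_1\", \"units\": 10, '\n",
"             '\"activation\": \"relu\", \"batch_input_shape\": [null, 4]}}, '\n",
"             '{\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_2\", \"units\": 10, \"activation\": \"relu\"}}, '\n",
"             '{\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_3\", \"units\": 3, \"activation\": \"softmax\"}}]}')\n",
"\n",
"arch = json.loads(arch_json)\n",
"layers = arch['config']  # Keras 2.1.6: a plain list of layer dicts\n",
"input_dim = layers[0]['config']['batch_input_shape'][1]\n",
"summary = [(layer['config']['name'], layer['config']['units'], layer['config']['activation'])\n",
"           for layer in layers]\n",
"print(input_dim)  # 4\n",
"print(summary)\n",
"```"
]
},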
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a model architecture with three hidden layers of 10 units each:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_4 (Dense) (None, 10) 50 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_6 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_7 (Dense) (None, 3) 33 \n",
"=================================================================\n",
"Total params: 303\n",
"Trainable params: 303\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model2 = Sequential()\n",
"model2.add(Dense(10, activation='relu', input_shape=(4,)))\n",
"model2.add(Dense(10, activation='relu'))\n",
"model2.add(Dense(10, activation='relu'))\n",
"model2.add(Dense(3, activation='softmax'))\n",
" \n",
"model2.summary()"
]
},
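{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compared with `model1`, `model2` has one more 10-unit `Dense` layer. That layer adds $10 \\cdot 10 + 10 = 110$ parameters, taking the total from 193 to 303. The small helper below reproduces the `Param #` column for any stack of fully connected layers. It assumes every layer is a `Dense` layer with a bias, as in both models here; `dense_param_counts` is a hypothetical name:\n",
"\n",
"```python\n",
"def dense_param_counts(n_inputs, layer_units):\n",
"    # Each Dense layer has n_inputs * units weights plus units biases.\n",
"    counts = []\n",
"    for units in layer_units:\n",
"        counts.append(n_inputs * units + units)\n",
"        n_inputs = units  # this layer's output feeds the next layer\n",
"    return counts\n",
"\n",
"model1_counts = dense_param_counts(4, [10, 10, 3])\n",
"model2_counts = dense_param_counts(4, [10, 10, 10, 3])\n",
"print('%s total=%d' % (model1_counts, sum(model1_counts)))  # [50, 110, 33] total=193\n",
"print('%s total=%d' % (model2_counts, sum(model2_counts)))  # [50, 110, 110, 33] total=303\n",
"```"
]
},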
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_5\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_6\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_7\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model2.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load both architectures into a model architecture table:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>model_arch</th>\n",
" <th>model_weights</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>__internal_madlib_id__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_1', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>None</td>\n",
" <td>Sophie</td>\n",
" <td>MLP with 1 hidden layer</td>\n",
" <td>__madlib_temp_96702431_1576708421_6956281__</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_4', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_5', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_6', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_7', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, 
u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>None</td>\n",
" <td>Maria</td>\n",
" <td>MLP with 2 hidden layers</td>\n",
" <td>__madlib_temp_85244704_1576708422_1853942__</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_1', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, None, u'Sophie', u'MLP with 1 hidden layer', u'__madlib_temp_96702431_1576708421_6956281__'),\n",
" (2, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_4', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_5', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_6', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_7', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, 
u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, None, u'Maria', u'MLP with 2 hidden layers', u'__madlib_temp_85244704_1576708422_1853942__')]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS model_arch_library;\n",
"\n",
"SELECT madlib.load_keras_model('model_arch_library',  -- Output table\n",
" \n",
"$$\n",
"{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}\n",
"$$\n",
"::json, -- JSON blob\n",
" NULL, -- Weights\n",
" 'Sophie', -- Name\n",
" 'MLP with 1 hidden layer' -- Descr\n",
");\n",
"\n",
"SELECT madlib.load_keras_model('model_arch_library',  -- Output table\n",
" \n",
"$$\n",
"{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_5\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_6\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_7\", \"kernel_constraint\": null, \"bias_regularizer\": null, 
\"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}\n",
"$$\n",
"::json, -- JSON blob\n",
" NULL, -- Weights\n",
" 'Maria', -- Name\n",
" 'MLP with 2 hidden layers' -- Descr\n",
");\n",
"\n",
"SELECT * FROM model_arch_library ORDER BY model_id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"def_mst\"></a>\n",
"# 4. Define and load model selection tuples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
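"Select the model(s) from the model architecture table that you want to run, along with the compile and fit parameters. Every combination of model ID, compile parameters and fit parameters is generated and loaded into the model selection table. Here, 2 models, 3 sets of compile parameters and 2 sets of fit parameters give 12 model selection tuples:"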
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
" (2, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
" (3, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
" (4, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
" (5, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
" (6, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
" (7, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
" (8, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
" (9, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
" (10, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
" (11, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
" (12, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1')]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS mst_table, mst_table_summary;\n",
"\n",
"SELECT madlib.load_model_selection_table('model_arch_library', -- model architecture table\n",
" 'mst_table', -- model selection table output\n",
" ARRAY[1,2], -- model ids from model architecture table\n",
" ARRAY[ -- compile params\n",
" $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,\n",
" $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,\n",
" $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$\n",
" ],\n",
" ARRAY[ -- fit params\n",
" $$batch_size=4,epochs=1$$,\n",
" $$batch_size=8,epochs=1$$\n",
" ]\n",
" );\n",
" \n",
"SELECT * FROM mst_table ORDER BY mst_key;"
]
},
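{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 12 rows above are the Cartesian product of the 2 model IDs, 3 compile parameter strings and 2 fit parameter strings passed to `load_model_selection_table()`. The Python sketch below reproduces that enumeration with `itertools.product` to make the row count and the `mst_key` numbering explicit. It is only an illustration. The compile and fit strings are abbreviated to the single value that varies in each (the Adam learning rate and `batch_size`), and the ordering is inferred from the output above, not taken from MADlib's implementation.\n",
"\n",
"```python\n",
"from itertools import product\n",
"\n",
"# Values from the load_model_selection_table() call above; the compile\n",
"# params differ only in the Adam learning rate and the fit params only\n",
"# in batch_size.\n",
"model_ids = [1, 2]\n",
"learning_rates = [0.1, 0.01, 0.001]\n",
"batch_sizes = [4, 8]\n",
"\n",
"# Cartesian product, numbered from 1 like the mst_key column\n",
"mst_rows = [\n",
"    (mst_key, model_id, lr, batch_size)\n",
"    for mst_key, (model_id, lr, batch_size) in enumerate(\n",
"        product(model_ids, learning_rates, batch_sizes), start=1)\n",
"]\n",
"\n",
"print(len(mst_rows))  # 12\n",
"print(mst_rows[2])    # (3, 1, 0.01, 4)\n",
"```"
]
},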
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is the name of the model architecture table that corresponds to the model selection table:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_arch_table</th>\n",
" </tr>\n",
" <tr>\n",
" <td>model_arch_library</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'model_arch_library',)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM mst_table_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"train\"></a>\n",
"# 5. Train\n",
"Train all of the models defined in the model selection table with a single call:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit_multiple_model</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;\n",
"\n",
"SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', -- source_table\n",
" 'iris_multi_model', -- model_output_table\n",
" 'mst_table', -- model_selection_table\n",
" 10, -- num_iterations\n",
" FALSE -- use gpus\n",
" );"
]
},
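{
"cell_type": "markdown",
"metadata": {},
"source": [
"According to the MADlib documentation, `madlib_keras_fit_multiple_model` uses model hopper parallelism. Each segment trains one model on its local partition of the data, and then the models move (hop) to other segments. In every iteration each model therefore sees all of the training data, while the data itself stays in place. The Python sketch below is a simplified round-robin version of such a schedule, not MADlib's actual scheduler, and `NUM_SEGMENTS = 3` is a hypothetical cluster size. It checks the two properties the scheme relies on. Within a hop, no model is on two segments at once. Over one iteration, every model visits every segment exactly once.\n",
"\n",
"```python\n",
"# Simplified round-robin hopping schedule (illustration only)\n",
"NUM_SEGMENTS = 3   # hypothetical cluster size\n",
"NUM_MODELS = 12    # one per mst_key in mst_table\n",
"\n",
"def hop_schedule(num_models, num_segments):\n",
"    # For each hop, list the model index trained on each segment.\n",
"    # Assumes num_models >= num_segments; unassigned models wait.\n",
"    return [[(seg + hop) % num_models for seg in range(num_segments)]\n",
"            for hop in range(num_models)]\n",
"\n",
"schedule = hop_schedule(NUM_MODELS, NUM_SEGMENTS)\n",
"\n",
"# No model is trained on two segments during the same hop\n",
"assert all(len(set(hop)) == len(hop) for hop in schedule)\n",
"\n",
"# Over one iteration, every model visits every segment exactly once\n",
"visits = {}\n",
"for hop in schedule:\n",
"    for seg, model in enumerate(hop):\n",
"        visits[(model, seg)] = visits.get((model, seg), 0) + 1\n",
"assert len(visits) == NUM_MODELS * NUM_SEGMENTS\n",
"assert set(visits.values()) == {1}\n",
"```"
]
},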
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the model summary:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>validation_table</th>\n",
" <th>model</th>\n",
" <th>model_info</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>num_iterations</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>warm_start</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>None</td>\n",
" <td>iris_multi_model</td>\n",
" <td>iris_multi_model_info</td>\n",
" <td>class_text</td>\n",
" <td>attributes</td>\n",
" <td>model_arch_library</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" <td>False</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>2019-12-18 22:33:49.706384</td>\n",
" <td>2019-12-18 22:35:34.547961</td>\n",
" <td>1.17-dev</td>\n",
" <td>3</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>character varying</td>\n",
" <td>1.0</td>\n",
" <td>[10]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', None, u'iris_multi_model', u'iris_multi_model_info', u'class_text', u'attributes', u'model_arch_library', 10, 10, False, None, None, datetime.datetime(2019, 12, 18, 22, 33, 49, 706384), datetime.datetime(2019, 12, 18, 22, 35, 34, 547961), u'1.17-dev', 3, [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], u'character varying', 1.0, [10])]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the results for each model, ordered by final training accuracy (highest first) and then by final training loss:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.148514986038208]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.975000023842</td>\n",
" <td>0.12241948396</td>\n",
" <td>[0.975000023841858]</td>\n",
" <td>[0.122419483959675]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.172315120697021]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.975000023842</td>\n",
" <td>0.123081341386</td>\n",
" <td>[0.975000023841858]</td>\n",
" <td>[0.123081341385841]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.274233102798462]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.925000011921</td>\n",
" <td>0.171397775412</td>\n",
" <td>[0.925000011920929]</td>\n",
" <td>[0.171397775411606]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.155992984771729]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.925000011921</td>\n",
" <td>0.51177251339</td>\n",
" <td>[0.925000011920929]</td>\n",
" <td>[0.511772513389587]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.220170021057129]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.908333361149</td>\n",
" <td>0.214677110314</td>\n",
" <td>[0.908333361148834]</td>\n",
" <td>[0.214677110314369]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.191344022750854]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.833333313465</td>\n",
" <td>0.524632036686</td>\n",
" <td>[0.833333313465118]</td>\n",
" <td>[0.524632036685944]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.181636810302734]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.758333325386</td>\n",
" <td>0.393412530422</td>\n",
" <td>[0.758333325386047]</td>\n",
" <td>[0.393412530422211]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.181061029434204]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.658333361149</td>\n",
" <td>0.474381148815</td>\n",
" <td>[0.658333361148834]</td>\n",
" <td>[0.474381148815155]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.20294713973999]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.658333361149</td>\n",
" <td>0.475430130959</td>\n",
" <td>[0.658333361148834]</td>\n",
" <td>[0.475430130958557]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.207202911376953]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.574999988079</td>\n",
" <td>0.885546028614</td>\n",
" <td>[0.574999988079071]</td>\n",
" <td>[0.885546028614044]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.374184846878052]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.433333337307</td>\n",
" <td>0.82793289423</td>\n",
" <td>[0.433333337306976]</td>\n",
" <td>[0.827932894229889]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.216787099838257]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.316666662693</td>\n",
" <td>1.10255157948</td>\n",
" <td>[0.316666662693024]</td>\n",
" <td>[1.1025515794754]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(4, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.148514986038208], [u'accuracy'], 0.975000023842, 0.12241948396, [0.975000023841858], [0.122419483959675], None, None, None, None),\n",
" (10, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.172315120697021], [u'accuracy'], 0.975000023842, 0.123081341386, [0.975000023841858], [0.123081341385841], None, None, None, None),\n",
" (9, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.274233102798462], [u'accuracy'], 0.925000011921, 0.171397775412, [0.925000011920929], [0.171397775411606], None, None, None, None),\n",
" (5, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.155992984771729], [u'accuracy'], 0.925000011921, 0.51177251339, [0.925000011920929], [0.511772513389587], None, None, None, None),\n",
" (3, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.220170021057129], [u'accuracy'], 0.908333361149, 0.214677110314, [0.908333361148834], [0.214677110314369], None, None, None, None),\n",
" (12, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.191344022750854], [u'accuracy'], 0.833333313465, 0.524632036686, [0.833333313465118], [0.524632036685944], None, None, None, None),\n",
" (8, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.181636810302734], [u'accuracy'], 0.758333325386, 0.393412530422, [0.758333325386047], [0.393412530422211], None, None, None, None),\n",
" (7, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.181061029434204], [u'accuracy'], 0.658333361149, 0.474381148815, [0.658333361148834], [0.474381148815155], None, None, None, None),\n",
" (2, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.20294713973999], [u'accuracy'], 0.658333361149, 0.475430130959, [0.658333361148834], [0.475430130958557], None, None, None, None),\n",
" (6, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.207202911376953], [u'accuracy'], 0.574999988079, 0.885546028614, [0.574999988079071], [0.885546028614044], None, None, None, None),\n",
" (11, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.374184846878052], [u'accuracy'], 0.433333337307, 0.82793289423, [0.433333337306976], [0.827932894229889], None, None, None, None),\n",
" (1, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.216787099838257], [u'accuracy'], 0.316666662693, 1.10255157948, [0.316666662693024], [1.1025515794754], None, None, None, None)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_info ORDER BY training_metrics_final DESC, training_loss_final;"
]
},
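{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ORDER BY` above ranks configurations by final training accuracy and breaks ties on final training loss. The Python sketch below applies the same two-key sort to a few `(mst_key, training_metrics_final, training_loss_final)` rows copied from the table. It shows why `mst_key` 4 ranks ahead of `mst_key` 10 even though both reach the same accuracy. These are training metrics only. The validation dataset section later in this notebook shows how to obtain validation metrics, which are usually a better basis for choosing a model.\n",
"\n",
"```python\n",
"# (mst_key, training_metrics_final, training_loss_final) from the table above\n",
"results = [\n",
"    (10, 0.975000023842, 0.123081341386),\n",
"    (4, 0.975000023842, 0.12241948396),\n",
"    (9, 0.925000011921, 0.171397775412),\n",
"    (5, 0.925000011921, 0.51177251339),\n",
"    (1, 0.316666662693, 1.10255157948),\n",
"]\n",
"\n",
"# Same ordering as the SQL: accuracy descending, then loss ascending\n",
"ranked = sorted(results, key=lambda r: (-r[1], r[2]))\n",
"\n",
"best_mst_key = ranked[0][0]\n",
"print(best_mst_key)            # 4\n",
"print([r[0] for r in ranked])  # [4, 10, 9, 5, 1]\n",
"```"
]
},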
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"eval\"></a>\n",
"# 6. Evaluate\n",
"\n",
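"Now evaluate one of the models built above against the test data. Because `iris_multi_model` contains one model per configuration, the `mst_key` argument selects which one to use (here, `mst_key` 3):"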
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>loss</th>\n",
" <th>metric</th>\n",
" <th>metrics_type</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.15500420332</td>\n",
" <td>0.966666638851</td>\n",
" <td>[u'accuracy']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.15500420331955, 0.966666638851166, [u'accuracy'])]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_validate;\n",
"SELECT madlib.madlib_keras_evaluate('iris_multi_model', -- model\n",
" 'iris_test_packed', -- test table\n",
" 'iris_validate', -- output table\n",
" NULL, -- use gpus\n",
" 3 -- mst_key to use\n",
" );\n",
"\n",
"SELECT * FROM iris_validate;"
]
},
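{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `loss` and `metric` columns above should follow the standard Keras definitions for the compile parameters used here: `categorical_crossentropy` loss and `accuracy` (categorical accuracy for one-hot labels). Below is a minimal plain-Python sketch of those two formulas. The function names and the two-row example are hypothetical, and the sketch omits the epsilon clipping Keras applies to predicted probabilities.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def categorical_crossentropy(y_true, y_pred):\n",
"    # Mean over samples of -sum_k y_true[k] * log(y_pred[k]).\n",
"    # Terms with y_true[k] == 0 contribute nothing, so they are skipped.\n",
"    total = 0.0\n",
"    for t_row, p_row in zip(y_true, y_pred):\n",
"        total -= sum(t * math.log(p) for t, p in zip(t_row, p_row) if t > 0)\n",
"    return total / len(y_true)\n",
"\n",
"def categorical_accuracy(y_true, y_pred):\n",
"    # Fraction of rows whose predicted argmax matches the true argmax.\n",
"    def argmax(row):\n",
"        return max(range(len(row)), key=lambda k: row[k])\n",
"    hits = sum(1 for t, p in zip(y_true, y_pred) if argmax(t) == argmax(p))\n",
"    return float(hits) / len(y_true)\n",
"\n",
"# Hypothetical example: two samples, three classes (setosa, versicolor, virginica).\n",
"y_true = [[1, 0, 0], [0, 1, 0]]\n",
"y_pred = [[0.8, 0.1, 0.1], [0.3, 0.6, 0.1]]\n",
"print(categorical_crossentropy(y_true, y_pred))  # about 0.367\n",
"print(categorical_accuracy(y_true, y_pred))      # 1.0\n",
"```\n",
"\n",
"On the 30-row test set, the accuracy of about 0.9667 reported above corresponds to 29 of 30 rows classified correctly."
]
},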
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred\"></a>\n",
"# 7. Predict\n",
"\n",
"Now predict using the same model. We use the same test data set for prediction that we used for evaluation, which is not usual but serves to show the syntax. The predicted class is in the `estimated_class_text` column:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"30 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>estimated_class_text</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>44</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>53</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>57</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>62</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>69</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>97</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>118</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>122</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>132</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3, u'Iris-setosa'),\n",
" (5, u'Iris-setosa'),\n",
" (7, u'Iris-setosa'),\n",
" (8, u'Iris-setosa'),\n",
" (10, u'Iris-setosa'),\n",
" (19, u'Iris-setosa'),\n",
" (25, u'Iris-setosa'),\n",
" (26, u'Iris-setosa'),\n",
" (28, u'Iris-setosa'),\n",
" (38, u'Iris-setosa'),\n",
" (44, u'Iris-setosa'),\n",
" (45, u'Iris-setosa'),\n",
" (51, u'Iris-versicolor'),\n",
" (53, u'Iris-versicolor'),\n",
" (57, u'Iris-versicolor'),\n",
" (59, u'Iris-versicolor'),\n",
" (62, u'Iris-versicolor'),\n",
" (69, u'Iris-virginica'),\n",
" (75, u'Iris-versicolor'),\n",
" (77, u'Iris-versicolor'),\n",
" (97, u'Iris-versicolor'),\n",
" (102, u'Iris-virginica'),\n",
" (107, u'Iris-virginica'),\n",
" (114, u'Iris-virginica'),\n",
" (118, u'Iris-virginica'),\n",
" (120, u'Iris-virginica'),\n",
" (122, u'Iris-virginica'),\n",
" (132, u'Iris-virginica'),\n",
" (146, u'Iris-virginica'),\n",
" (147, u'Iris-virginica')]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict;\n",
"\n",
"SELECT madlib.madlib_keras_predict('iris_multi_model', -- model\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict', -- output table\n",
" 'response', -- prediction type\n",
" FALSE, -- use gpus\n",
" 3 -- mst_key to use\n",
" );\n",
"\n",
"SELECT * FROM iris_predict ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count misclassifications:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1L,)]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM iris_predict JOIN iris_test USING (id) \n",
"WHERE iris_predict.estimated_class_text != iris_test.class_text;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the test accuracy as the percentage of correctly classified rows:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>test_accuracy_percent</th>\n",
" </tr>\n",
" <tr>\n",
" <td>96.67</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('96.67'),)]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT round(100.0 * SUM(CASE WHEN p.estimated_class_text = t.class_text THEN 1 ELSE 0 END) / COUNT(*), 2) AS test_accuracy_percent\n",
"FROM iris_predict p JOIN iris_test t USING (id);"
]
},
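{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same bookkeeping can also be done client-side. Below is a minimal sketch, assuming the predicted and actual class labels have already been fetched into Python and keyed by `id` (for example from a `%sql` query result). The helper name `score_predictions` and the example data are hypothetical. It computes both the misclassification count and the accuracy percentage over the ids present in both inputs.\n",
"\n",
"```python\n",
"def score_predictions(actual, predicted):\n",
"    # actual, predicted: dicts mapping row id -> class label.\n",
"    # Inner-join on id, then count mismatches and the accuracy in percent.\n",
"    ids = [i for i in actual if i in predicted]\n",
"    if not ids:\n",
"        raise ValueError('no overlapping ids between actual and predicted')\n",
"    wrong = sum(1 for i in ids if actual[i] != predicted[i])\n",
"    accuracy_pct = round(100.0 * (len(ids) - wrong) / len(ids), 2)\n",
"    return wrong, accuracy_pct\n",
"\n",
"# Hypothetical data: 30 rows, one of which is misclassified.\n",
"actual = {i: 'Iris-setosa' for i in range(1, 31)}\n",
"predicted = dict(actual)\n",
"predicted[5] = 'Iris-versicolor'\n",
"print(score_predictions(actual, predicted))  # (1, 96.67)\n",
"```"
]
},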
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"class2\"></a>\n",
"# Classification with Other Parameters\n",
"\n",
"<a id=\"val_dataset\"></a>\n",
"# 1. Validation dataset\n",
"\n",
"Now use a validation dataset and compute metrics every 3rd iteration using the 'metrics_compute_frequency' parameter. This can help reduce run time if you do not need metrics computed at every iteration."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit_multiple_model</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;\n",
"\n",
"SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', -- source_table\n",
" 'iris_multi_model', -- model_output_table\n",
" 'mst_table', -- model_selection_table\n",
" 10, -- num_iterations\n",
" FALSE, -- use gpus\n",
" 'iris_test_packed', -- validation dataset\n",
" 3, -- metrics compute frequency\n",
" FALSE, -- warm start\n",
" 'Sophie L.', -- name\n",
" 'Model selection for iris dataset' -- description\n",
" );"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the model summary:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>validation_table</th>\n",
" <th>model</th>\n",
" <th>model_info</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>num_iterations</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>warm_start</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_test_packed</td>\n",
" <td>iris_multi_model</td>\n",
" <td>iris_multi_model_info</td>\n",
" <td>class_text</td>\n",
" <td>attributes</td>\n",
" <td>model_arch_library</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>Sophie L.</td>\n",
" <td>Model selection for iris dataset</td>\n",
" <td>2019-12-18 22:35:49.962345</td>\n",
" <td>2019-12-18 22:37:51.230499</td>\n",
" <td>1.17-dev</td>\n",
" <td>3</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>character varying</td>\n",
" <td>1.0</td>\n",
" <td>[3, 6, 9, 10]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_test_packed', u'iris_multi_model', u'iris_multi_model_info', u'class_text', u'attributes', u'model_arch_library', 10, 3, False, u'Sophie L.', u'Model selection for iris dataset', datetime.datetime(2019, 12, 18, 22, 35, 49, 962345), datetime.datetime(2019, 12, 18, 22, 37, 51, 230499), u'1.17-dev', 3, [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], u'character varying', 1.0, [3, 6, 9, 10])]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_summary;"
]
},
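{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `metrics_iters` column above, `[3, 6, 9, 10]`, shows when metrics were computed with `num_iterations=10` and `metrics_compute_frequency=3`: every 3rd iteration, plus the final iteration. This is also why each `training_metrics` and `validation_metrics` array in the model info table below has four entries. Below is a minimal plain-Python sketch of that schedule. The rule is inferred from this output rather than taken from the MADlib source, and the function name `metrics_iterations` is hypothetical.\n",
"\n",
"```python\n",
"def metrics_iterations(num_iterations, metrics_compute_frequency):\n",
"    # Every k-th iteration up to and including num_iterations...\n",
"    iters = list(range(metrics_compute_frequency, num_iterations + 1,\n",
"                       metrics_compute_frequency))\n",
"    # ...plus the last iteration, so final metrics are always reported.\n",
"    if not iters or iters[-1] != num_iterations:\n",
"        iters.append(num_iterations)\n",
"    return iters\n",
"\n",
"print(metrics_iterations(10, 3))  # [3, 6, 9, 10]\n",
"print(metrics_iterations(10, 5))  # [5, 10]\n",
"```"
]
},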
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the performance of each model:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.400555849075317, 0.175060987472534, 0.161082029342651, 0.159379005432129]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.958333313465</td>\n",
" <td>0.370426625013</td>\n",
" <td>[0.841666638851166, 0.875, 0.958333313465118, 0.958333313465118]</td>\n",
" <td>[0.597030103206635, 0.467845916748047, 0.394165992736816, 0.370426625013351]</td>\n",
" <td>1.0</td>\n",
" <td>0.32715767622</td>\n",
" <td>[0.866666674613953, 0.933333337306976, 1.0, 1.0]</td>\n",
" <td>[0.587784588336945, 0.432697623968124, 0.352933287620544, 0.32715767621994]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.157984018325806, 0.146160840988159, 0.446839094161987, 0.217149972915649]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.916666686535</td>\n",
" <td>0.176682218909</td>\n",
" <td>[0.958333313465118, 0.891666650772095, 0.841666638851166, 0.916666686534882]</td>\n",
" <td>[0.340974450111389, 0.224177747964859, 0.315857976675034, 0.176682218909264]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.146555349231</td>\n",
" <td>[0.966666638851166, 0.933333337306976, 0.866666674613953, 0.966666638851166]</td>\n",
" <td>[0.306026995182037, 0.204480707645416, 0.291850447654724, 0.146555349230766]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.158334016799927, 0.492121934890747, 0.168816804885864, 0.160614013671875]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.949999988079</td>\n",
" <td>0.137093007565</td>\n",
" <td>[0.75, 0.808333337306976, 0.941666662693024, 0.949999988079071]</td>\n",
" <td>[0.861838400363922, 0.306531131267548, 0.267581582069397, 0.137093007564545]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.0812632590532</td>\n",
" <td>[0.533333361148834, 0.733333349227905, 1.0, 0.966666638851166]</td>\n",
" <td>[1.17265951633453, 0.347328811883926, 0.0795030668377876, 0.0812632590532303]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.206979990005493, 0.175852060317993, 0.18351411819458, 0.173283100128174]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.841666638851</td>\n",
" <td>0.319059103727</td>\n",
" <td>[0.833333313465118, 0.916666686534882, 0.958333313465118, 0.841666638851166]</td>\n",
" <td>[0.375581055879593, 0.235803470015526, 0.119093284010887, 0.319059103727341]</td>\n",
" <td>0.866666674614</td>\n",
" <td>0.294114112854</td>\n",
" <td>[0.866666674613953, 0.966666638851166, 0.933333337306976, 0.866666674613953]</td>\n",
" <td>[0.332203418016434, 0.206457450985909, 0.09817935526371, 0.294114112854004]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.154335021972656, 0.14276385307312, 0.160094022750854, 0.147177934646606]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.833333313465</td>\n",
" <td>0.315035998821</td>\n",
" <td>[0.850000023841858, 0.966666638851166, 0.966666638851166, 0.833333313465118]</td>\n",
" <td>[0.39260533452034, 0.207864001393318, 0.14202418923378, 0.315035998821259]</td>\n",
" <td>0.833333313465</td>\n",
" <td>0.287047833204</td>\n",
" <td>[0.833333313465118, 0.966666638851166, 0.933333337306976, 0.833333313465118]</td>\n",
" <td>[0.350265830755234, 0.179627984762192, 0.119969591498375, 0.287047833204269]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.183771848678589, 0.442173957824707, 0.196517944335938, 0.183962106704712]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.683333337307</td>\n",
" <td>0.773626208305</td>\n",
" <td>[0.983333349227905, 0.783333361148834, 0.841666638851166, 0.683333337306976]</td>\n",
" <td>[0.323956668376923, 0.355609774589539, 0.289077579975128, 0.773626208305359]</td>\n",
" <td>0.733333349228</td>\n",
" <td>0.598832905293</td>\n",
" <td>[0.966666638851166, 0.733333349227905, 0.866666674613953, 0.733333349227905]</td>\n",
" <td>[0.292185336351395, 0.310099214315414, 0.278687566518784, 0.598832905292511]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.215842962265015, 0.183883190155029, 0.181258201599121, 0.233398914337158]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.658333361149</td>\n",
" <td>0.501300632954</td>\n",
" <td>[0.341666668653488, 0.658333361148834, 0.658333361148834, 0.658333361148834]</td>\n",
" <td>[0.947986364364624, 0.807084918022156, 0.549242556095123, 0.501300632953644]</td>\n",
" <td>0.699999988079</td>\n",
" <td>0.459856539965</td>\n",
" <td>[0.300000011920929, 0.699999988079071, 0.699999988079071, 0.699999988079071]</td>\n",
" <td>[0.971994161605835, 0.821518063545227, 0.513974606990814, 0.459856539964676]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.181059837341309, 0.156504154205322, 0.154800891876221, 0.165037870407104]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.675000011921</td>\n",
" <td>0.500130057335</td>\n",
" <td>[0.658333361148834, 0.908333361148834, 0.908333361148834, 0.675000011920929]</td>\n",
" <td>[0.822371363639832, 0.354260504245758, 0.206746637821198, 0.5001300573349]</td>\n",
" <td>0.699999988079</td>\n",
" <td>0.511800050735</td>\n",
" <td>[0.699999988079071, 0.933333337306976, 0.966666638851166, 0.699999988079071]</td>\n",
" <td>[0.784473180770874, 0.314396589994431, 0.171932756900787, 0.511800050735474]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.16503119468689, 0.165420055389404, 0.163087844848633, 0.157285213470459]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.600000023842</td>\n",
" <td>0.536593079567</td>\n",
" <td>[0.625, 0.491666674613953, 0.508333325386047, 0.600000023841858]</td>\n",
" <td>[0.877406716346741, 0.665770947933197, 0.563206613063812, 0.536593079566956]</td>\n",
" <td>0.600000023842</td>\n",
" <td>0.50565046072</td>\n",
" <td>[0.566666662693024, 0.533333361148834, 0.600000023841858, 0.600000023841858]</td>\n",
" <td>[0.898801684379578, 0.642534494400024, 0.529698371887207, 0.505650460720062]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.180193901062012, 0.230684041976929, 0.202606916427612, 0.182677030563354]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.5</td>\n",
" <td>1.01774513721</td>\n",
" <td>[0.341666668653488, 0.491666674613953, 0.524999976158142, 0.5]</td>\n",
" <td>[1.10608339309692, 1.06158423423767, 1.02908384799957, 1.01774513721466]</td>\n",
" <td>0.5</td>\n",
" <td>1.01636135578</td>\n",
" <td>[0.300000011920929, 0.466666668653488, 0.466666668653488, 0.5]</td>\n",
" <td>[1.10331404209137, 1.05365967750549, 1.02413082122803, 1.01636135578156]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.181950092315674, 0.197594881057739, 0.187069177627563, 0.183701992034912]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.316666662693</td>\n",
" <td>1.10080897808</td>\n",
" <td>[0.316666662693024, 0.341666668653488, 0.341666668653488, 0.316666662693024]</td>\n",
" <td>[1.1043815612793, 1.11140048503876, 1.09834468364716, 1.10080897808075]</td>\n",
" <td>0.40000000596</td>\n",
" <td>1.09380173683</td>\n",
" <td>[0.400000005960464, 0.300000011920929, 0.300000011920929, 0.400000005960464]</td>\n",
" <td>[1.09075009822845, 1.09998726844788, 1.10155093669891, 1.09380173683167]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.182392835617065, 0.206873893737793, 0.192094087600708, 0.185320854187012]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.341666668653</td>\n",
" <td>1.10410153866</td>\n",
" <td>[0.341666668653488, 0.316666662693024, 0.341666668653488, 0.341666668653488]</td>\n",
" <td>[1.10291886329651, 1.10132431983948, 1.10635650157928, 1.10410153865814]</td>\n",
" <td>0.300000011921</td>\n",
" <td>1.10918176174</td>\n",
" <td>[0.300000011920929, 0.400000005960464, 0.300000011920929, 0.300000011920929]</td>\n",
" <td>[1.10382485389709, 1.09316170215607, 1.1332186460495, 1.10918176174164]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(6, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.400555849075317, 0.175060987472534, 0.161082029342651, 0.159379005432129], [u'accuracy'], 0.958333313465, 0.370426625013, [0.841666638851166, 0.875, 0.958333313465118, 0.958333313465118], [0.597030103206635, 0.467845916748047, 0.394165992736816, 0.370426625013351], 1.0, 0.32715767622, [0.866666674613953, 0.933333337306976, 1.0, 1.0], [0.587784588336945, 0.432697623968124, 0.352933287620544, 0.32715767621994]),\n",
" (3, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.157984018325806, 0.146160840988159, 0.446839094161987, 0.217149972915649], [u'accuracy'], 0.916666686535, 0.176682218909, [0.958333313465118, 0.891666650772095, 0.841666638851166, 0.916666686534882], [0.340974450111389, 0.224177747964859, 0.315857976675034, 0.176682218909264], 0.966666638851, 0.146555349231, [0.966666638851166, 0.933333337306976, 0.866666674613953, 0.966666638851166], [0.306026995182037, 0.204480707645416, 0.291850447654724, 0.146555349230766]),\n",
" (1, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.158334016799927, 0.492121934890747, 0.168816804885864, 0.160614013671875], [u'accuracy'], 0.949999988079, 0.137093007565, [0.75, 0.808333337306976, 0.941666662693024, 0.949999988079071], [0.861838400363922, 0.306531131267548, 0.267581582069397, 0.137093007564545], 0.966666638851, 0.0812632590532, [0.533333361148834, 0.733333349227905, 1.0, 0.966666638851166], [1.17265951633453, 0.347328811883926, 0.0795030668377876, 0.0812632590532303]),\n",
" (10, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.206979990005493, 0.175852060317993, 0.18351411819458, 0.173283100128174], [u'accuracy'], 0.841666638851, 0.319059103727, [0.833333313465118, 0.916666686534882, 0.958333313465118, 0.841666638851166], [0.375581055879593, 0.235803470015526, 0.119093284010887, 0.319059103727341], 0.866666674614, 0.294114112854, [0.866666674613953, 0.966666638851166, 0.933333337306976, 0.866666674613953], [0.332203418016434, 0.206457450985909, 0.09817935526371, 0.294114112854004]),\n",
" (4, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.154335021972656, 0.14276385307312, 0.160094022750854, 0.147177934646606], [u'accuracy'], 0.833333313465, 0.315035998821, [0.850000023841858, 0.966666638851166, 0.966666638851166, 0.833333313465118], [0.39260533452034, 0.207864001393318, 0.14202418923378, 0.315035998821259], 0.833333313465, 0.287047833204, [0.833333313465118, 0.966666638851166, 0.933333337306976, 0.833333313465118], [0.350265830755234, 0.179627984762192, 0.119969591498375, 0.287047833204269]),\n",
" (9, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.183771848678589, 0.442173957824707, 0.196517944335938, 0.183962106704712], [u'accuracy'], 0.683333337307, 0.773626208305, [0.983333349227905, 0.783333361148834, 0.841666638851166, 0.683333337306976], [0.323956668376923, 0.355609774589539, 0.289077579975128, 0.773626208305359], 0.733333349228, 0.598832905293, [0.966666638851166, 0.733333349227905, 0.866666674613953, 0.733333349227905], [0.292185336351395, 0.310099214315414, 0.278687566518784, 0.598832905292511]),\n",
" (11, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.215842962265015, 0.183883190155029, 0.181258201599121, 0.233398914337158], [u'accuracy'], 0.658333361149, 0.501300632954, [0.341666668653488, 0.658333361148834, 0.658333361148834, 0.658333361148834], [0.947986364364624, 0.807084918022156, 0.549242556095123, 0.501300632953644], 0.699999988079, 0.459856539965, [0.300000011920929, 0.699999988079071, 0.699999988079071, 0.699999988079071], [0.971994161605835, 0.821518063545227, 0.513974606990814, 0.459856539964676]),\n",
" (2, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.181059837341309, 0.156504154205322, 0.154800891876221, 0.165037870407104], [u'accuracy'], 0.675000011921, 0.500130057335, [0.658333361148834, 0.908333361148834, 0.908333361148834, 0.675000011920929], [0.822371363639832, 0.354260504245758, 0.206746637821198, 0.5001300573349], 0.699999988079, 0.511800050735, [0.699999988079071, 0.933333337306976, 0.966666638851166, 0.699999988079071], [0.784473180770874, 0.314396589994431, 0.171932756900787, 0.511800050735474]),\n",
" (5, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.16503119468689, 0.165420055389404, 0.163087844848633, 0.157285213470459], [u'accuracy'], 0.600000023842, 0.536593079567, [0.625, 0.491666674613953, 0.508333325386047, 0.600000023841858], [0.877406716346741, 0.665770947933197, 0.563206613063812, 0.536593079566956], 0.600000023842, 0.50565046072, [0.566666662693024, 0.533333361148834, 0.600000023841858, 0.600000023841858], [0.898801684379578, 0.642534494400024, 0.529698371887207, 0.505650460720062]),\n",
" (12, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.180193901062012, 0.230684041976929, 0.202606916427612, 0.182677030563354], [u'accuracy'], 0.5, 1.01774513721, [0.341666668653488, 0.491666674613953, 0.524999976158142, 0.5], [1.10608339309692, 1.06158423423767, 1.02908384799957, 1.01774513721466], 0.5, 1.01636135578, [0.300000011920929, 0.466666668653488, 0.466666668653488, 0.5], [1.10331404209137, 1.05365967750549, 1.02413082122803, 1.01636135578156]),\n",
" (7, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.181950092315674, 0.197594881057739, 0.187069177627563, 0.183701992034912], [u'accuracy'], 0.316666662693, 1.10080897808, [0.316666662693024, 0.341666668653488, 0.341666668653488, 0.316666662693024], [1.1043815612793, 1.11140048503876, 1.09834468364716, 1.10080897808075], 0.40000000596, 1.09380173683, [0.400000005960464, 0.300000011920929, 0.300000011920929, 0.400000005960464], [1.09075009822845, 1.09998726844788, 1.10155093669891, 1.09380173683167]),\n",
" (8, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.182392835617065, 0.206873893737793, 0.192094087600708, 0.185320854187012], [u'accuracy'], 0.341666668653, 1.10410153866, [0.341666668653488, 0.316666662693024, 0.341666668653488, 0.341666668653488], [1.10291886329651, 1.10132431983948, 1.10635650157928, 1.10410153865814], 0.300000011921, 1.10918176174, [0.300000011920929, 0.400000005960464, 0.300000011920929, 0.300000011920929], [1.10382485389709, 1.09316170215607, 1.1332186460495, 1.10918176174164])]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_info ORDER BY validation_metrics_final DESC;"
]
},
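{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ordering by `validation_metrics_final` alone leaves ties unresolved. In the output above, mst_key 3 and mst_key 1 both reach a final validation accuracy of about 0.967. Below is a minimal sketch of a tie-break on the lower final validation loss. It uses `(mst_key, validation_metrics_final, validation_loss_final)` values copied from the first five rows of the table above, and the helper name `rank_models` is hypothetical. In SQL, the equivalent would be adding `validation_loss_final` as a second `ORDER BY` key.\n",
"\n",
"```python\n",
"# (mst_key, validation_metrics_final, validation_loss_final), copied from\n",
"# the first five rows of iris_multi_model_info above.\n",
"rows = [\n",
"    (6, 1.0, 0.32715767622),\n",
"    (3, 0.966666638851, 0.146555349231),\n",
"    (1, 0.966666638851, 0.0812632590532),\n",
"    (10, 0.866666674614, 0.294114112854),\n",
"    (4, 0.833333313465, 0.287047833204),\n",
"]\n",
"\n",
"def rank_models(rows):\n",
"    # Higher accuracy first; among equal accuracies, lower loss first.\n",
"    return sorted(rows, key=lambda r: (-r[1], r[2]))\n",
"\n",
"print([r[0] for r in rank_models(rows)])  # [6, 1, 3, 10, 4]\n",
"```"
]
},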
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot the validation results:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"from collections import defaultdict\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"sns.set_palette(sns.color_palette(\"hls\", 20))\n",
"plt.rcParams.update({'font.size': 12})\n",
"pd.set_option('display.max_colwidth', -1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" fig.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn it's self\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" event.shiftKey = false;\n",
" // Send a \"J\" for go to next cell\n",
" event.which = 74;\n",
" event.keyCode = 74;\n",
" manager.command_mode();\n",
" manager.handle_keydown(event);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"1000\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x12e9ae7d0>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"\n",
"df_results = %sql SELECT * FROM iris_multi_model_info ORDER BY validation_loss ASC LIMIT 7;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"df_summary = %sql SELECT * FROM iris_multi_model_summary;\n",
"df_summary = df_summary.DataFrame()\n",
"\n",
"# Set up plots: validation metric on the left, validation loss on the right\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Validation metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Validation loss curve')\n",
"\n",
"# Iterations at which the validation metric and loss were computed\n",
"iters = df_summary['metrics_iters'][0]\n",
"\n",
"# Plot one curve per model selection tuple (mst_key)\n",
"for mst_key in df_results['mst_key']:\n",
"    df_output_info = %sql SELECT validation_metrics,validation_loss FROM iris_multi_model_info WHERE mst_key = $mst_key\n",
"    df_output_info = df_output_info.DataFrame()\n",
"    validation_metrics = df_output_info['validation_metrics'][0]\n",
"    validation_loss = df_output_info['validation_loss'][0]\n",
"\n",
"    ax_metric.plot(iters, validation_metrics, label=mst_key, marker='o')\n",
"    ax_loss.plot(iters, validation_loss, label=mst_key, marker='o')\n",
"\n",
"# Add the legends after plotting (so the labeled curves exist), one per subplot\n",
"ax_metric.legend()\n",
"ax_loss.legend()\n",
"fig.tight_layout()\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred_prob\"></a>\n",
"# 2. Predict probabilities\n",
"\n",
"Predict the probability of each class for the test set, using the model selection tuple with `mst_key` 3:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"30 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>prob_Iris-setosa</th>\n",
" <th>prob_Iris-versicolor</th>\n",
" <th>prob_Iris-virginica</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.9999416</td>\n",
" <td>5.8360623e-05</td>\n",
" <td>3.9093355e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.99998116</td>\n",
" <td>1.8880675e-05</td>\n",
" <td>2.5342377e-13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.99994814</td>\n",
" <td>5.1881765e-05</td>\n",
" <td>2.5964983e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.99996114</td>\n",
" <td>3.8810744e-05</td>\n",
" <td>1.176443e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.99992573</td>\n",
" <td>7.4317446e-05</td>\n",
" <td>5.4237942e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>0.9999845</td>\n",
" <td>1.5514812e-05</td>\n",
" <td>1.034207e-13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>0.99992156</td>\n",
" <td>7.845682e-05</td>\n",
" <td>3.7364413e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>0.9998591</td>\n",
" <td>0.00014085071</td>\n",
" <td>2.0146884e-11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>0.9999734</td>\n",
" <td>2.6542659e-05</td>\n",
" <td>4.8342347e-13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>0.99992573</td>\n",
" <td>7.4317446e-05</td>\n",
" <td>5.4237942e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>44</td>\n",
" <td>0.99990726</td>\n",
" <td>9.278052e-05</td>\n",
" <td>6.9040372e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>0.999964</td>\n",
" <td>3.6013742e-05</td>\n",
" <td>5.7615945e-13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>0.00025041687</td>\n",
" <td>0.99780566</td>\n",
" <td>0.0019439155</td>\n",
" </tr>\n",
" <tr>\n",
" <td>53</td>\n",
" <td>1.843269e-05</td>\n",
" <td>0.9889865</td>\n",
" <td>0.010995116</td>\n",
" </tr>\n",
" <tr>\n",
" <td>57</td>\n",
" <td>2.4158675e-05</td>\n",
" <td>0.99005336</td>\n",
" <td>0.00992243</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>0.00011159414</td>\n",
" <td>0.9942708</td>\n",
" <td>0.0056176083</td>\n",
" </tr>\n",
" <tr>\n",
" <td>62</td>\n",
" <td>0.00014697485</td>\n",
" <td>0.99189115</td>\n",
" <td>0.007961868</td>\n",
" </tr>\n",
" <tr>\n",
" <td>69</td>\n",
" <td>8.6406266e-07</td>\n",
" <td>0.6961896</td>\n",
" <td>0.30380967</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>0.0005239165</td>\n",
" <td>0.9965855</td>\n",
" <td>0.0028905326</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77</td>\n",
" <td>1.5155997e-05</td>\n",
" <td>0.97978914</td>\n",
" <td>0.020195633</td>\n",
" </tr>\n",
" <tr>\n",
" <td>97</td>\n",
" <td>0.00023696794</td>\n",
" <td>0.9938279</td>\n",
" <td>0.005935215</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>1.3247301e-09</td>\n",
" <td>0.18419608</td>\n",
" <td>0.8158039</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>2.5100556e-08</td>\n",
" <td>0.30281228</td>\n",
" <td>0.69718766</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>3.2222575e-10</td>\n",
" <td>0.08682407</td>\n",
" <td>0.913176</td>\n",
" </tr>\n",
" <tr>\n",
" <td>118</td>\n",
" <td>5.33606e-11</td>\n",
" <td>0.34179842</td>\n",
" <td>0.6582016</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>9.134116e-09</td>\n",
" <td>0.27099058</td>\n",
" <td>0.72900945</td>\n",
" </tr>\n",
" <tr>\n",
" <td>122</td>\n",
" <td>2.9710499e-09</td>\n",
" <td>0.21993305</td>\n",
" <td>0.7800669</td>\n",
" </tr>\n",
" <tr>\n",
" <td>132</td>\n",
" <td>5.2177818e-09</td>\n",
" <td>0.8370931</td>\n",
" <td>0.16290687</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>1.4404147e-09</td>\n",
" <td>0.2293714</td>\n",
" <td>0.7706286</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>3.8019614e-09</td>\n",
" <td>0.2240861</td>\n",
" <td>0.77591395</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3, 0.9999416, 5.8360623e-05, 3.9093355e-12),\n",
" (5, 0.99998116, 1.8880675e-05, 2.5342377e-13),\n",
" (7, 0.99994814, 5.1881765e-05, 2.5964983e-12),\n",
" (8, 0.99996114, 3.8810744e-05, 1.176443e-12),\n",
" (10, 0.99992573, 7.4317446e-05, 5.4237942e-12),\n",
" (19, 0.9999845, 1.5514812e-05, 1.034207e-13),\n",
" (25, 0.99992156, 7.845682e-05, 3.7364413e-12),\n",
" (26, 0.9998591, 0.00014085071, 2.0146884e-11),\n",
" (28, 0.9999734, 2.6542659e-05, 4.8342347e-13),\n",
" (38, 0.99992573, 7.4317446e-05, 5.4237942e-12),\n",
" (44, 0.99990726, 9.278052e-05, 6.9040372e-12),\n",
" (45, 0.999964, 3.6013742e-05, 5.7615945e-13),\n",
" (51, 0.00025041687, 0.99780566, 0.0019439155),\n",
" (53, 1.843269e-05, 0.9889865, 0.010995116),\n",
" (57, 2.4158675e-05, 0.99005336, 0.00992243),\n",
" (59, 0.00011159414, 0.9942708, 0.0056176083),\n",
" (62, 0.00014697485, 0.99189115, 0.007961868),\n",
" (69, 8.6406266e-07, 0.6961896, 0.30380967),\n",
" (75, 0.0005239165, 0.9965855, 0.0028905326),\n",
" (77, 1.5155997e-05, 0.97978914, 0.020195633),\n",
" (97, 0.00023696794, 0.9938279, 0.005935215),\n",
" (102, 1.3247301e-09, 0.18419608, 0.8158039),\n",
" (107, 2.5100556e-08, 0.30281228, 0.69718766),\n",
" (114, 3.2222575e-10, 0.08682407, 0.913176),\n",
" (118, 5.33606e-11, 0.34179842, 0.6582016),\n",
" (120, 9.134116e-09, 0.27099058, 0.72900945),\n",
" (122, 2.9710499e-09, 0.21993305, 0.7800669),\n",
" (132, 5.2177818e-09, 0.8370931, 0.16290687),\n",
" (146, 1.4404147e-09, 0.2293714, 0.7706286),\n",
" (147, 3.8019614e-09, 0.2240861, 0.77591395)]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict;\n",
"\n",
"SELECT madlib.madlib_keras_predict('iris_multi_model', -- model\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict', -- output table\n",
" 'prob', -- prediction type\n",
" FALSE, -- use gpus\n",
" 3 -- mst_key to use\n",
" );\n",
"\n",
"SELECT * FROM iris_predict ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"warm_start\"></a>\n",
"# 3. Warm start\n",
"\n",
"Next, set the `warm_start` parameter to TRUE to continue training from the weights learned in the run above. Note that we do not drop the model output table or its associated info and summary tables:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit_multiple_model</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', -- source_table\n",
" 'iris_multi_model', -- model_output_table\n",
" 'mst_table', -- model_selection_table\n",
" 3, -- num_iterations\n",
" FALSE, -- use gpus\n",
" 'iris_test_packed', -- validation dataset\n",
" 1, -- metrics compute frequency\n",
" TRUE, -- warm start\n",
" 'Sophie L.', -- name\n",
" 'Simple MLP for iris dataset' -- description\n",
" );"
]
},
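{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because warm start initializes each model from its previously trained weights, the training loss in this run should pick up roughly where the previous run left off, rather than starting from the higher loss of a freshly initialized network. A minimal sketch of a query for that check lists the first and final training loss of each model. It assumes `training_loss` is stored as a PostgreSQL array, so `training_loss[1]` is the loss recorded after the first iteration of this run. Compare the first column with the final training losses from the earlier run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- First and last training loss of the warm-started run, per model configuration\n",
"-- (PostgreSQL arrays are 1-indexed, so training_loss[1] is the first recorded iteration)\n",
"SELECT mst_key,\n",
"       training_loss[1] AS first_iteration_loss,\n",
"       training_loss_final\n",
"FROM iris_multi_model_info\n",
"ORDER BY mst_key;"
]
},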
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View summary:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>validation_table</th>\n",
" <th>model</th>\n",
" <th>model_info</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>num_iterations</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>warm_start</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_test_packed</td>\n",
" <td>iris_multi_model</td>\n",
" <td>iris_multi_model_info</td>\n",
" <td>class_text</td>\n",
" <td>attributes</td>\n",
" <td>model_arch_library</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>True</td>\n",
" <td>Sophie L.</td>\n",
" <td>Simple MLP for iris dataset</td>\n",
" <td>2019-12-18 22:37:57.948805</td>\n",
" <td>2019-12-18 22:38:43.967187</td>\n",
" <td>1.17-dev</td>\n",
" <td>3</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>character varying</td>\n",
" <td>1.0</td>\n",
" <td>[1, 2, 3]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_test_packed', u'iris_multi_model', u'iris_multi_model_info', u'class_text', u'attributes', u'model_arch_library', 3, 1, True, u'Sophie L.', u'Simple MLP for iris dataset', datetime.datetime(2019, 12, 18, 22, 37, 57, 948805), datetime.datetime(2019, 12, 18, 22, 38, 43, 967187), u'1.17-dev', 3, [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], u'character varying', 1.0, [1, 2, 3])]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the performance of each model, sorted by final validation accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.17091703414917, 0.163390159606934, 0.155634164810181]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.958333313465</td>\n",
" <td>0.31917694211</td>\n",
" <td>[0.958333313465118, 0.958333313465118, 0.958333313465118]</td>\n",
" <td>[0.348434448242188, 0.334388434886932, 0.319176942110062]</td>\n",
" <td>1.0</td>\n",
" <td>0.272621482611</td>\n",
" <td>[1.0, 1.0, 1.0]</td>\n",
" <td>[0.306039541959763, 0.28966349363327, 0.272621482610703]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.172316074371338, 0.188217163085938, 0.503840208053589]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.193531006575</td>\n",
" <td>[0.958333313465118, 0.925000011920929, 0.899999976158142]</td>\n",
" <td>[0.147025644779205, 0.144938006997108, 0.193531006574631]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.153077676892</td>\n",
" <td>[0.966666638851166, 0.966666638851166, 0.966666638851166]</td>\n",
" <td>[0.132363379001617, 0.116448685526848, 0.153077676892281]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.147105932235718, 0.158121824264526, 0.174723863601685]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.100400544703</td>\n",
" <td>[0.966666638851166, 0.908333361148834, 0.966666638851166]</td>\n",
" <td>[0.112152323126793, 0.197978660464287, 0.100400544703007]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.0844493880868</td>\n",
" <td>[0.933333337306976, 0.966666638851166, 0.966666638851166]</td>\n",
" <td>[0.0945712551474571, 0.170254677534103, 0.0844493880867958]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.224463939666748, 0.412797927856445, 0.193319797515869]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.958333313465</td>\n",
" <td>0.139601364732</td>\n",
" <td>[0.966666638851166, 0.966666638851166, 0.958333313465118]</td>\n",
" <td>[0.122705578804016, 0.0809410735964775, 0.139601364731789]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.131209135056</td>\n",
" <td>[0.966666638851166, 0.966666638851166, 0.966666638851166]</td>\n",
" <td>[0.115778811275959, 0.0698963403701782, 0.131209135055542]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.160850048065186, 0.224483013153076, 0.163106918334961]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.0839553326368</td>\n",
" <td>[0.966666638851166, 0.908333361148834, 0.966666638851166]</td>\n",
" <td>[0.124577566981316, 0.196399554610252, 0.0839553326368332]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.074150800705</td>\n",
" <td>[0.966666638851166, 0.866666674613953, 0.966666638851166]</td>\n",
" <td>[0.137340381741524, 0.232466518878937, 0.0741508007049561]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.14374303817749, 0.154287099838257, 0.17367696762085]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.0860244855285</td>\n",
" <td>[0.966666638851166, 0.841666638851166, 0.966666638851166]</td>\n",
" <td>[0.0824147835373878, 0.337884455919266, 0.0860244855284691]</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.0704526007175</td>\n",
" <td>[0.966666638851166, 0.866666674613953, 0.933333337306976]</td>\n",
" <td>[0.0690516456961632, 0.295713990926743, 0.0704526007175446]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.155812978744507, 0.158360004425049, 0.159363031387329]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.833333313465</td>\n",
" <td>0.344228476286</td>\n",
" <td>[0.666666686534882, 0.675000011920929, 0.833333313465118]</td>\n",
" <td>[1.01126325130463, 1.33927237987518, 0.344228476285934]</td>\n",
" <td>0.800000011921</td>\n",
" <td>0.305708706379</td>\n",
" <td>[0.699999988079071, 0.699999988079071, 0.800000011920929]</td>\n",
" <td>[1.02303433418274, 1.36952638626099, 0.305708706378937]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.187958955764771, 0.186024904251099, 0.501762866973877]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.725000023842</td>\n",
" <td>0.423261642456</td>\n",
" <td>[0.658333361148834, 0.658333361148834, 0.725000023841858]</td>\n",
" <td>[0.46866175532341, 0.445532470941544, 0.423261642456055]</td>\n",
" <td>0.699999988079</td>\n",
" <td>0.378630697727</td>\n",
" <td>[0.699999988079071, 0.699999988079071, 0.699999988079071]</td>\n",
" <td>[0.422465175390244, 0.398104608058929, 0.378630697727203]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>[0.176413059234619, 0.169157981872559, 0.15624213218689]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.675000011921</td>\n",
" <td>0.470171242952</td>\n",
" <td>[0.641666650772095, 0.658333361148834, 0.675000011920929]</td>\n",
" <td>[0.504463493824005, 0.486825525760651, 0.470171242952347]</td>\n",
" <td>0.699999988079</td>\n",
" <td>0.436036229134</td>\n",
" <td>[0.699999988079071, 0.699999988079071, 0.699999988079071]</td>\n",
" <td>[0.470719456672668, 0.452698260545731, 0.436036229133606]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.164397954940796, 0.486438035964966, 0.192479133605957]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.550000011921</td>\n",
" <td>0.975017726421</td>\n",
" <td>[0.508333325386047, 0.533333361148834, 0.550000011920929]</td>\n",
" <td>[1.00239539146423, 0.986684203147888, 0.975017726421356]</td>\n",
" <td>0.466666668653</td>\n",
" <td>0.981434583664</td>\n",
" <td>[0.5, 0.466666668653488, 0.466666668653488]</td>\n",
" <td>[1.00223970413208, 0.989481270313263, 0.98143458366394]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=8,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.467766046524048, 0.198179006576538, 0.186810970306396]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.341666668653</td>\n",
" <td>1.10613942146</td>\n",
" <td>[0.316666662693024, 0.316666662693024, 0.341666668653488]</td>\n",
" <td>[1.1275190114975, 1.10920584201813, 1.10613942146301]</td>\n",
" <td>0.300000011921</td>\n",
" <td>1.10817503929</td>\n",
" <td>[0.400000005960464, 0.400000005960464, 0.300000011920929]</td>\n",
" <td>[1.10070872306824, 1.09047472476959, 1.10817503929138]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']</td>\n",
" <td>batch_size=4,epochs=1</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.2197265625</td>\n",
" <td>[0.467660903930664, 0.195011138916016, 0.185934066772461]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.341666668653</td>\n",
" <td>1.10524618626</td>\n",
" <td>[0.316666662693024, 0.341666668653488, 0.341666668653488]</td>\n",
" <td>[1.10246300697327, 1.09976887702942, 1.10524618625641]</td>\n",
" <td>0.300000011921</td>\n",
" <td>1.10809886456</td>\n",
" <td>[0.400000005960464, 0.300000011920929, 0.300000011920929]</td>\n",
" <td>[1.09229254722595, 1.09808218479156, 1.10809886455536]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(6, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.17091703414917, 0.163390159606934, 0.155634164810181], [u'accuracy'], 0.958333313465, 0.31917694211, [0.958333313465118, 0.958333313465118, 0.958333313465118], [0.348434448242188, 0.334388434886932, 0.319176942110062], 1.0, 0.272621482611, [1.0, 1.0, 1.0], [0.306039541959763, 0.28966349363327, 0.272621482610703]),\n",
" (10, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.172316074371338, 0.188217163085938, 0.503840208053589], [u'accuracy'], 0.899999976158, 0.193531006575, [0.958333313465118, 0.925000011920929, 0.899999976158142], [0.147025644779205, 0.144938006997108, 0.193531006574631], 0.966666638851, 0.153077676892, [0.966666638851166, 0.966666638851166, 0.966666638851166], [0.132363379001617, 0.116448685526848, 0.153077676892281]),\n",
" (4, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.147105932235718, 0.158121824264526, 0.174723863601685], [u'accuracy'], 0.966666638851, 0.100400544703, [0.966666638851166, 0.908333361148834, 0.966666638851166], [0.112152323126793, 0.197978660464287, 0.100400544703007], 0.966666638851, 0.0844493880868, [0.933333337306976, 0.966666638851166, 0.966666638851166], [0.0945712551474571, 0.170254677534103, 0.0844493880867958]),\n",
" (9, 2, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.224463939666748, 0.412797927856445, 0.193319797515869], [u'accuracy'], 0.958333313465, 0.139601364732, [0.966666638851166, 0.966666638851166, 0.958333313465118], [0.122705578804016, 0.0809410735964775, 0.139601364731789], 0.966666638851, 0.131209135056, [0.966666638851166, 0.966666638851166, 0.966666638851166], [0.115778811275959, 0.0698963403701782, 0.131209135055542]),\n",
" (1, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.160850048065186, 0.224483013153076, 0.163106918334961], [u'accuracy'], 0.966666638851, 0.0839553326368, [0.966666638851166, 0.908333361148834, 0.966666638851166], [0.124577566981316, 0.196399554610252, 0.0839553326368332], 0.966666638851, 0.074150800705, [0.966666638851166, 0.866666674613953, 0.966666638851166], [0.137340381741524, 0.232466518878937, 0.0741508007049561]),\n",
" (3, 1, u\"loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.14374303817749, 0.154287099838257, 0.17367696762085], [u'accuracy'], 0.966666638851, 0.0860244855285, [0.966666638851166, 0.841666638851166, 0.966666638851166], [0.0824147835373878, 0.337884455919266, 0.0860244855284691], 0.933333337307, 0.0704526007175, [0.966666638851166, 0.866666674613953, 0.933333337306976], [0.0690516456961632, 0.295713990926743, 0.0704526007175446]),\n",
" (2, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 0.7900390625, [0.155812978744507, 0.158360004425049, 0.159363031387329], [u'accuracy'], 0.833333313465, 0.344228476286, [0.666666686534882, 0.675000011920929, 0.833333313465118], [1.01126325130463, 1.33927237987518, 0.344228476285934], 0.800000011921, 0.305708706379, [0.699999988079071, 0.699999988079071, 0.800000011920929], [1.02303433418274, 1.36952638626099, 0.305708706378937]),\n",
" (11, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.187958955764771, 0.186024904251099, 0.501762866973877], [u'accuracy'], 0.725000023842, 0.423261642456, [0.658333361148834, 0.658333361148834, 0.725000023841858], [0.46866175532341, 0.445532470941544, 0.423261642456055], 0.699999988079, 0.378630697727, [0.699999988079071, 0.699999988079071, 0.699999988079071], [0.422465175390244, 0.398104608058929, 0.378630697727203]),\n",
" (5, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 0.7900390625, [0.176413059234619, 0.169157981872559, 0.15624213218689], [u'accuracy'], 0.675000011921, 0.470171242952, [0.641666650772095, 0.658333361148834, 0.675000011920929], [0.504463493824005, 0.486825525760651, 0.470171242952347], 0.699999988079, 0.436036229134, [0.699999988079071, 0.699999988079071, 0.699999988079071], [0.470719456672668, 0.452698260545731, 0.436036229133606]),\n",
" (12, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.164397954940796, 0.486438035964966, 0.192479133605957], [u'accuracy'], 0.550000011921, 0.975017726421, [0.508333325386047, 0.533333361148834, 0.550000011920929], [1.00239539146423, 0.986684203147888, 0.975017726421356], 0.466666668653, 0.981434583664, [0.5, 0.466666668653488, 0.466666668653488], [1.00223970413208, 0.989481270313263, 0.98143458366394]),\n",
" (8, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=8,epochs=1', u'madlib_keras', 1.2197265625, [0.467766046524048, 0.198179006576538, 0.186810970306396], [u'accuracy'], 0.341666668653, 1.10613942146, [0.316666662693024, 0.316666662693024, 0.341666668653488], [1.1275190114975, 1.10920584201813, 1.10613942146301], 0.300000011921, 1.10817503929, [0.400000005960464, 0.400000005960464, 0.300000011920929], [1.10070872306824, 1.09047472476959, 1.10817503929138]),\n",
" (7, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']\", u'batch_size=4,epochs=1', u'madlib_keras', 1.2197265625, [0.467660903930664, 0.195011138916016, 0.185934066772461], [u'accuracy'], 0.341666668653, 1.10524618626, [0.316666662693024, 0.341666668653488, 0.341666668653488], [1.10246300697327, 1.09976887702942, 1.10524618625641], 0.300000011921, 1.10809886456, [0.400000005960464, 0.300000011920929, 0.300000011920929], [1.09229254722595, 1.09808218479156, 1.10809886455536])]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_info ORDER BY validation_metrics_final DESC;"
]
},
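{
"cell_type": "markdown",
"metadata": {},
"source": [
"Several configurations tie on final validation accuracy in the table above: mst_keys 10, 4, 9 and 1 all finish at about 0.967. On a small validation set, accuracy moves in coarse steps, so ties like this are common. One way to break them is to add `validation_loss_final` as a second sort key and keep only the columns needed to pick a model. This sketch assumes that a lower final validation loss is the preferred secondary criterion:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Rank configurations by final validation accuracy, then break ties by lower final validation loss\n",
"SELECT mst_key, model_id, compile_params, fit_params,\n",
"       validation_metrics_final, validation_loss_final\n",
"FROM iris_multi_model_info\n",
"ORDER BY validation_metrics_final DESC, validation_loss_final ASC;"
]
},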
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot validation results:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" fig.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handle_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";\n",
"var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable'\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn itself\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" event.shiftKey = false;\n",
" // Send a \"J\" for go to next cell\n",
" event.which = 74;\n",
" event.keyCode = 74;\n",
" manager.command_mode();\n",
" manager.handle_keydown(event);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x130da4150>"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from matplotlib.ticker import MaxNLocator\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Fetch the 7 model configurations with the lowest final validation loss\n",
"df_results = %sql SELECT * FROM iris_multi_model_info ORDER BY validation_loss_final ASC LIMIT 7;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"# The summary table records the iterations at which metrics were computed\n",
"df_summary = %sql SELECT * FROM iris_multi_model_summary;\n",
"df_summary = df_summary.DataFrame()\n",
"iters = df_summary['metrics_iters'][0]\n",
"\n",
"# Set up plots: validation metric on the left, validation loss on the right\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Validation metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Validation loss curve')\n",
"\n",
"# The per-iteration validation curves are already columns of df_results,\n",
"# so there is no need to query the info table again for each model\n",
"for mst_key, metrics, loss in zip(df_results['mst_key'],\n",
"                                  df_results['validation_metrics'],\n",
"                                  df_results['validation_loss']):\n",
"    ax_metric.plot(iters, metrics, label=mst_key, marker='o')\n",
"    ax_loss.plot(iters, loss, label=mst_key, marker='o')\n",
"\n",
"# Add the legends only after the lines exist, then tighten the layout\n",
"ax_metric.legend(title='mst_key')\n",
"ax_loss.legend(title='mst_key')\n",
"fig.tight_layout()\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 1
}