{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Selection for Multilayer Perceptron Using Keras and MADlib\n",
"\n",
"End-to-end (E2E) classification example in which MADlib calls a Keras multilayer perceptron (MLP) to train models across different hyperparameters and model architectures.\n",
"\n",
"Deep learning works best on very large datasets, but those are not convenient for a quick introduction to the syntax. So in this notebook we use the well-known iris data set from https://archive.ics.uci.edu/ml/datasets/iris to help you get started. It is similar to the example in the user docs at http://madlib.apache.org/docs/latest/index.html\n",
"\n",
"For more realistic examples, please refer to the deep learning notebooks at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts\n",
"\n",
"## Table of contents\n",
"\n",
"<a href=\"#class\">Classification</a>\n",
"\n",
"* <a href=\"#create_input_data\">1. Create input data</a>\n",
"\n",
"* <a href=\"#pp\">2. Call preprocessor for deep learning</a>\n",
"\n",
"* <a href=\"#load\">3. Define and load model architecture</a>\n",
"\n",
"* <a href=\"#def_mst\">4. Define and load model selection tuples</a>\n",
"\n",
"* <a href=\"#train\">5. Train</a>\n",
"\n",
"* <a href=\"#eval\">6. Evaluate</a>\n",
"\n",
"* <a href=\"#pred\">7. Predict</a>\n",
"\n",
"<a href=\"#class2\">Classification with Other Parameters</a>\n",
"\n",
"* <a href=\"#val_dataset\">1. Validation dataset</a>\n",
"\n",
"* <a href=\"#pred_prob\">2. Predict probabilities</a>\n",
"\n",
"* <a href=\"#warm_start\">3. Warm start</a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Greenplum Database 5.x on GCP - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.18.0-dev, git revision: rel/v1.17.0-89-g14a91ce, cmake configuration time: Fri Mar 5 23:08:38 UTC 2021, build type: release, build system: Linux-3.10.0-1160.11.1.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.18.0-dev, git revision: rel/v1.17.0-89-g14a91ce, cmake configuration time: Fri Mar 5 23:08:38 UTC 2021, build type: release, build system: Linux-3.10.0-1160.11.1.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"class\"></a>\n",
"# Classification\n",
"\n",
"<a id=\"create_input_data\"></a>\n",
"# 1. Create input data\n",
"\n",
"Load the iris data set."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"150 rows affected.\n",
"150 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>attributes</th>\n",
" <th>class_text</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>[Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>[Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>[Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>[Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>[Decimal('5.7'), Decimal('3.8'), Decimal('1.7'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.5'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>[Decimal('5.4'), Decimal('3.4'), Decimal('1.7'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>[Decimal('5.1'), Decimal('3.7'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>[Decimal('4.6'), Decimal('3.6'), Decimal('1.0'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>[Decimal('5.1'), Decimal('3.3'), Decimal('1.7'), Decimal('0.5')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.9'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>[Decimal('5.0'), Decimal('3.0'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.6'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>[Decimal('5.2'), Decimal('3.5'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>[Decimal('5.2'), Decimal('3.4'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>[Decimal('4.8'), Decimal('3.1'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32</td>\n",
" <td>[Decimal('5.4'), Decimal('3.4'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33</td>\n",
" <td>[Decimal('5.2'), Decimal('4.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>[Decimal('5.5'), Decimal('4.2'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36</td>\n",
" <td>[Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37</td>\n",
" <td>[Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>39</td>\n",
" <td>[Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>[Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>41</td>\n",
" <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>42</td>\n",
" <td>[Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>[Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>44</td>\n",
" <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>46</td>\n",
" <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>47</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>48</td>\n",
" <td>[Decimal('4.6'), Decimal('3.2'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>[Decimal('5.3'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>[Decimal('5.0'), Decimal('3.3'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>[Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>52</td>\n",
" <td>[Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>53</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>54</td>\n",
" <td>[Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>[Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>57</td>\n",
" <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>58</td>\n",
" <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>[Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>62</td>\n",
" <td>[Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>63</td>\n",
" <td>[Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>[Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>65</td>\n",
" <td>[Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>66</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>68</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>69</td>\n",
" <td>[Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>[Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>71</td>\n",
" <td>[Decimal('5.9'), Decimal('3.2'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>[Decimal('6.1'), Decimal('2.8'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>73</td>\n",
" <td>[Decimal('6.3'), Decimal('2.5'), Decimal('4.9'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>74</td>\n",
" <td>[Decimal('6.1'), Decimal('2.8'), Decimal('4.7'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>[Decimal('6.4'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>[Decimal('6.6'), Decimal('3.0'), Decimal('4.4'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77</td>\n",
" <td>[Decimal('6.8'), Decimal('2.8'), Decimal('4.8'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>78</td>\n",
" <td>[Decimal('6.7'), Decimal('3.0'), Decimal('5.0'), Decimal('1.7')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>79</td>\n",
" <td>[Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>[Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>81</td>\n",
" <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>83</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>[Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>85</td>\n",
" <td>[Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>86</td>\n",
" <td>[Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>87</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>88</td>\n",
" <td>[Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>90</td>\n",
" <td>[Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>91</td>\n",
" <td>[Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>93</td>\n",
" <td>[Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>[Decimal('5.0'), Decimal('2.3'), Decimal('3.3'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>95</td>\n",
" <td>[Decimal('5.6'), Decimal('2.7'), Decimal('4.2'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>96</td>\n",
" <td>[Decimal('5.7'), Decimal('3.0'), Decimal('4.2'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>97</td>\n",
" <td>[Decimal('5.7'), Decimal('2.9'), Decimal('4.2'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>[Decimal('6.2'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>[Decimal('5.1'), Decimal('2.5'), Decimal('3.0'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.1'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>[Decimal('6.3'), Decimal('3.3'), Decimal('6.0'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>[Decimal('7.1'), Decimal('3.0'), Decimal('5.9'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>104</td>\n",
" <td>[Decimal('6.3'), Decimal('2.9'), Decimal('5.6'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>105</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.8'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>106</td>\n",
" <td>[Decimal('7.6'), Decimal('3.0'), Decimal('6.6'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>[Decimal('4.9'), Decimal('2.5'), Decimal('4.5'), Decimal('1.7')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>108</td>\n",
" <td>[Decimal('7.3'), Decimal('2.9'), Decimal('6.3'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>109</td>\n",
" <td>[Decimal('6.7'), Decimal('2.5'), Decimal('5.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>110</td>\n",
" <td>[Decimal('7.2'), Decimal('3.6'), Decimal('6.1'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>111</td>\n",
" <td>[Decimal('6.5'), Decimal('3.2'), Decimal('5.1'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>[Decimal('6.4'), Decimal('2.7'), Decimal('5.3'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>[Decimal('6.8'), Decimal('3.0'), Decimal('5.5'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>[Decimal('5.7'), Decimal('2.5'), Decimal('5.0'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>[Decimal('5.8'), Decimal('2.8'), Decimal('5.1'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>[Decimal('6.4'), Decimal('3.2'), Decimal('5.3'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.5'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>118</td>\n",
" <td>[Decimal('7.7'), Decimal('3.8'), Decimal('6.7'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>[Decimal('7.7'), Decimal('2.6'), Decimal('6.9'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>[Decimal('6.0'), Decimal('2.2'), Decimal('5.0'), Decimal('1.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>[Decimal('6.9'), Decimal('3.2'), Decimal('5.7'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>122</td>\n",
" <td>[Decimal('5.6'), Decimal('2.8'), Decimal('4.9'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>[Decimal('7.7'), Decimal('2.8'), Decimal('6.7'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>124</td>\n",
" <td>[Decimal('6.3'), Decimal('2.7'), Decimal('4.9'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>[Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>126</td>\n",
" <td>[Decimal('7.2'), Decimal('3.2'), Decimal('6.0'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>[Decimal('6.2'), Decimal('2.8'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>128</td>\n",
" <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.9'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>129</td>\n",
" <td>[Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>130</td>\n",
" <td>[Decimal('7.2'), Decimal('3.0'), Decimal('5.8'), Decimal('1.6')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>131</td>\n",
" <td>[Decimal('7.4'), Decimal('2.8'), Decimal('6.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>132</td>\n",
" <td>[Decimal('7.9'), Decimal('3.8'), Decimal('6.4'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>133</td>\n",
" <td>[Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>134</td>\n",
" <td>[Decimal('6.3'), Decimal('2.8'), Decimal('5.1'), Decimal('1.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>135</td>\n",
" <td>[Decimal('6.1'), Decimal('2.6'), Decimal('5.6'), Decimal('1.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>[Decimal('7.7'), Decimal('3.0'), Decimal('6.1'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>137</td>\n",
" <td>[Decimal('6.3'), Decimal('3.4'), Decimal('5.6'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>138</td>\n",
" <td>[Decimal('6.4'), Decimal('3.1'), Decimal('5.5'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>139</td>\n",
" <td>[Decimal('6.0'), Decimal('3.0'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>140</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('5.4'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>141</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('5.6'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>142</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('5.1'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>143</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>[Decimal('6.8'), Decimal('3.2'), Decimal('5.9'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>[Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>[Decimal('6.7'), Decimal('3.0'), Decimal('5.2'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>[Decimal('6.3'), Decimal('2.5'), Decimal('5.0'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.2'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>[Decimal('6.2'), Decimal('3.4'), Decimal('5.4'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>150</td>\n",
" <td>[Decimal('5.9'), Decimal('3.0'), Decimal('5.1'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (2, [Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (3, [Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (4, [Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (5, [Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (6, [Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')], u'Iris-setosa'),\n",
" (7, [Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (8, [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (9, [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (10, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (11, [Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (12, [Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (13, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')], u'Iris-setosa'),\n",
" (14, [Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')], u'Iris-setosa'),\n",
" (15, [Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')], u'Iris-setosa'),\n",
" (16, [Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (17, [Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')], u'Iris-setosa'),\n",
" (18, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (19, [Decimal('5.7'), Decimal('3.8'), Decimal('1.7'), Decimal('0.3')], u'Iris-setosa'),\n",
" (20, [Decimal('5.1'), Decimal('3.8'), Decimal('1.5'), Decimal('0.3')], u'Iris-setosa'),\n",
" (21, [Decimal('5.4'), Decimal('3.4'), Decimal('1.7'), Decimal('0.2')], u'Iris-setosa'),\n",
" (22, [Decimal('5.1'), Decimal('3.7'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (23, [Decimal('4.6'), Decimal('3.6'), Decimal('1.0'), Decimal('0.2')], u'Iris-setosa'),\n",
" (24, [Decimal('5.1'), Decimal('3.3'), Decimal('1.7'), Decimal('0.5')], u'Iris-setosa'),\n",
" (25, [Decimal('4.8'), Decimal('3.4'), Decimal('1.9'), Decimal('0.2')], u'Iris-setosa'),\n",
" (26, [Decimal('5.0'), Decimal('3.0'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (27, [Decimal('5.0'), Decimal('3.4'), Decimal('1.6'), Decimal('0.4')], u'Iris-setosa'),\n",
" (28, [Decimal('5.2'), Decimal('3.5'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (29, [Decimal('5.2'), Decimal('3.4'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (30, [Decimal('4.7'), Decimal('3.2'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (31, [Decimal('4.8'), Decimal('3.1'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (32, [Decimal('5.4'), Decimal('3.4'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (33, [Decimal('5.2'), Decimal('4.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (34, [Decimal('5.5'), Decimal('4.2'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (35, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (36, [Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')], u'Iris-setosa'),\n",
" (37, [Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (38, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (39, [Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (40, [Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (41, [Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')], u'Iris-setosa'),\n",
" (42, [Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')], u'Iris-setosa'),\n",
" (43, [Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (44, [Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')], u'Iris-setosa'),\n",
" (45, [Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')], u'Iris-setosa'),\n",
" (46, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (47, [Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (48, [Decimal('4.6'), Decimal('3.2'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (49, [Decimal('5.3'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (50, [Decimal('5.0'), Decimal('3.3'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (51, [Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (52, [Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (53, [Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (54, [Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (55, [Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (56, [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (57, [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (58, [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (59, [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (60, [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (61, [Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (62, [Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (63, [Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (64, [Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (65, [Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (66, [Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (67, [Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (68, [Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (69, [Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (70, [Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (71, [Decimal('5.9'), Decimal('3.2'), Decimal('4.8'), Decimal('1.8')], u'Iris-versicolor'),\n",
" (72, [Decimal('6.1'), Decimal('2.8'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (73, [Decimal('6.3'), Decimal('2.5'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (74, [Decimal('6.1'), Decimal('2.8'), Decimal('4.7'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (75, [Decimal('6.4'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (76, [Decimal('6.6'), Decimal('3.0'), Decimal('4.4'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (77, [Decimal('6.8'), Decimal('2.8'), Decimal('4.8'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (78, [Decimal('6.7'), Decimal('3.0'), Decimal('5.0'), Decimal('1.7')], u'Iris-versicolor'),\n",
" (79, [Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (80, [Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (81, [Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (82, [Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (83, [Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (84, [Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (85, [Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (86, [Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (87, [Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (88, [Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (89, [Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (90, [Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (91, [Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (92, [Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (93, [Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (94, [Decimal('5.0'), Decimal('2.3'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (95, [Decimal('5.6'), Decimal('2.7'), Decimal('4.2'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (96, [Decimal('5.7'), Decimal('3.0'), Decimal('4.2'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (97, [Decimal('5.7'), Decimal('2.9'), Decimal('4.2'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (98, [Decimal('6.2'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (99, [Decimal('5.1'), Decimal('2.5'), Decimal('3.0'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (100, [Decimal('5.7'), Decimal('2.8'), Decimal('4.1'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (101, [Decimal('6.3'), Decimal('3.3'), Decimal('6.0'), Decimal('2.5')], u'Iris-virginica'),\n",
" (102, [Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (103, [Decimal('7.1'), Decimal('3.0'), Decimal('5.9'), Decimal('2.1')], u'Iris-virginica'),\n",
" (104, [Decimal('6.3'), Decimal('2.9'), Decimal('5.6'), Decimal('1.8')], u'Iris-virginica'),\n",
" (105, [Decimal('6.5'), Decimal('3.0'), Decimal('5.8'), Decimal('2.2')], u'Iris-virginica'),\n",
" (106, [Decimal('7.6'), Decimal('3.0'), Decimal('6.6'), Decimal('2.1')], u'Iris-virginica'),\n",
" (107, [Decimal('4.9'), Decimal('2.5'), Decimal('4.5'), Decimal('1.7')], u'Iris-virginica'),\n",
" (108, [Decimal('7.3'), Decimal('2.9'), Decimal('6.3'), Decimal('1.8')], u'Iris-virginica'),\n",
" (109, [Decimal('6.7'), Decimal('2.5'), Decimal('5.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (110, [Decimal('7.2'), Decimal('3.6'), Decimal('6.1'), Decimal('2.5')], u'Iris-virginica'),\n",
" (111, [Decimal('6.5'), Decimal('3.2'), Decimal('5.1'), Decimal('2.0')], u'Iris-virginica'),\n",
" (112, [Decimal('6.4'), Decimal('2.7'), Decimal('5.3'), Decimal('1.9')], u'Iris-virginica'),\n",
" (113, [Decimal('6.8'), Decimal('3.0'), Decimal('5.5'), Decimal('2.1')], u'Iris-virginica'),\n",
" (114, [Decimal('5.7'), Decimal('2.5'), Decimal('5.0'), Decimal('2.0')], u'Iris-virginica'),\n",
" (115, [Decimal('5.8'), Decimal('2.8'), Decimal('5.1'), Decimal('2.4')], u'Iris-virginica'),\n",
" (116, [Decimal('6.4'), Decimal('3.2'), Decimal('5.3'), Decimal('2.3')], u'Iris-virginica'),\n",
" (117, [Decimal('6.5'), Decimal('3.0'), Decimal('5.5'), Decimal('1.8')], u'Iris-virginica'),\n",
" (118, [Decimal('7.7'), Decimal('3.8'), Decimal('6.7'), Decimal('2.2')], u'Iris-virginica'),\n",
" (119, [Decimal('7.7'), Decimal('2.6'), Decimal('6.9'), Decimal('2.3')], u'Iris-virginica'),\n",
" (120, [Decimal('6.0'), Decimal('2.2'), Decimal('5.0'), Decimal('1.5')], u'Iris-virginica'),\n",
" (121, [Decimal('6.9'), Decimal('3.2'), Decimal('5.7'), Decimal('2.3')], u'Iris-virginica'),\n",
" (122, [Decimal('5.6'), Decimal('2.8'), Decimal('4.9'), Decimal('2.0')], u'Iris-virginica'),\n",
" (123, [Decimal('7.7'), Decimal('2.8'), Decimal('6.7'), Decimal('2.0')], u'Iris-virginica'),\n",
" (124, [Decimal('6.3'), Decimal('2.7'), Decimal('4.9'), Decimal('1.8')], u'Iris-virginica'),\n",
" (125, [Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.1')], u'Iris-virginica'),\n",
" (126, [Decimal('7.2'), Decimal('3.2'), Decimal('6.0'), Decimal('1.8')], u'Iris-virginica'),\n",
" (127, [Decimal('6.2'), Decimal('2.8'), Decimal('4.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (128, [Decimal('6.1'), Decimal('3.0'), Decimal('4.9'), Decimal('1.8')], u'Iris-virginica'),\n",
" (129, [Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.1')], u'Iris-virginica'),\n",
" (130, [Decimal('7.2'), Decimal('3.0'), Decimal('5.8'), Decimal('1.6')], u'Iris-virginica'),\n",
" (131, [Decimal('7.4'), Decimal('2.8'), Decimal('6.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (132, [Decimal('7.9'), Decimal('3.8'), Decimal('6.4'), Decimal('2.0')], u'Iris-virginica'),\n",
" (133, [Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.2')], u'Iris-virginica'),\n",
" (134, [Decimal('6.3'), Decimal('2.8'), Decimal('5.1'), Decimal('1.5')], u'Iris-virginica'),\n",
" (135, [Decimal('6.1'), Decimal('2.6'), Decimal('5.6'), Decimal('1.4')], u'Iris-virginica'),\n",
" (136, [Decimal('7.7'), Decimal('3.0'), Decimal('6.1'), Decimal('2.3')], u'Iris-virginica'),\n",
" (137, [Decimal('6.3'), Decimal('3.4'), Decimal('5.6'), Decimal('2.4')], u'Iris-virginica'),\n",
" (138, [Decimal('6.4'), Decimal('3.1'), Decimal('5.5'), Decimal('1.8')], u'Iris-virginica'),\n",
" (139, [Decimal('6.0'), Decimal('3.0'), Decimal('4.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (140, [Decimal('6.9'), Decimal('3.1'), Decimal('5.4'), Decimal('2.1')], u'Iris-virginica'),\n",
" (141, [Decimal('6.7'), Decimal('3.1'), Decimal('5.6'), Decimal('2.4')], u'Iris-virginica'),\n",
" (142, [Decimal('6.9'), Decimal('3.1'), Decimal('5.1'), Decimal('2.3')], u'Iris-virginica'),\n",
" (143, [Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (144, [Decimal('6.8'), Decimal('3.2'), Decimal('5.9'), Decimal('2.3')], u'Iris-virginica'),\n",
" (145, [Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.5')], u'Iris-virginica'),\n",
" (146, [Decimal('6.7'), Decimal('3.0'), Decimal('5.2'), Decimal('2.3')], u'Iris-virginica'),\n",
" (147, [Decimal('6.3'), Decimal('2.5'), Decimal('5.0'), Decimal('1.9')], u'Iris-virginica'),\n",
" (148, [Decimal('6.5'), Decimal('3.0'), Decimal('5.2'), Decimal('2.0')], u'Iris-virginica'),\n",
" (149, [Decimal('6.2'), Decimal('3.4'), Decimal('5.4'), Decimal('2.3')], u'Iris-virginica'),\n",
" (150, [Decimal('5.9'), Decimal('3.0'), Decimal('5.1'), Decimal('1.8')], u'Iris-virginica')]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS iris_data;\n",
"\n",
"CREATE TABLE iris_data(\n",
" id serial,\n",
" attributes numeric[],\n",
" class_text varchar\n",
");\n",
"\n",
"INSERT INTO iris_data(id, attributes, class_text) VALUES\n",
"(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),\n",
"(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),\n",
"(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),\n",
"(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),\n",
"(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),\n",
"(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),\n",
"(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),\n",
"(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),\n",
"(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),\n",
"(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),\n",
"(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),\n",
"(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),\n",
"(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),\n",
"(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),\n",
"(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),\n",
"(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),\n",
"(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),\n",
"(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),\n",
"(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),\n",
"(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),\n",
"(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),\n",
"(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),\n",
"(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),\n",
"(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),\n",
"(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),\n",
"(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),\n",
"(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),\n",
"(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),\n",
"(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),\n",
"(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),\n",
"(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),\n",
"(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),\n",
"(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),\n",
"(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),\n",
"(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),\n",
"(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),\n",
"(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),\n",
"(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),\n",
"(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),\n",
"(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),\n",
"(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),\n",
"(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),\n",
"(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),\n",
"(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),\n",
"(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),\n",
"(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),\n",
"(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),\n",
"(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),\n",
"(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),\n",
"(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),\n",
"(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),\n",
"(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),\n",
"(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),\n",
"(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),\n",
"(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),\n",
"(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),\n",
"(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),\n",
"(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),\n",
"(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),\n",
"(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),\n",
"(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),\n",
"(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),\n",
"(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),\n",
"(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),\n",
"(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),\n",
"(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),\n",
"(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),\n",
"(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),\n",
"(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),\n",
"(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),\n",
"(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),\n",
"(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),\n",
"(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),\n",
"(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),\n",
"(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),\n",
"(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),\n",
"(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),\n",
"(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),\n",
"(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),\n",
"(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),\n",
"(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),\n",
"(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),\n",
"(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),\n",
"(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),\n",
"(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),\n",
"(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),\n",
"(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),\n",
"(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),\n",
"(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),\n",
"(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),\n",
"(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),\n",
"(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),\n",
"(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),\n",
"(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),\n",
"(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),\n",
"(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),\n",
"(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),\n",
"(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),\n",
"(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),\n",
"(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),\n",
"(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),\n",
"(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),\n",
"(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),\n",
"(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),\n",
"(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),\n",
"(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),\n",
"(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),\n",
"(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),\n",
"(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),\n",
"(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),\n",
"(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),\n",
"(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),\n",
"(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),\n",
"(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),\n",
"(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),\n",
"(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),\n",
"(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),\n",
"(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),\n",
"(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),\n",
"(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),\n",
"(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),\n",
"(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),\n",
"(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),\n",
"(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),\n",
"(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),\n",
"(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),\n",
"(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),\n",
"(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),\n",
"(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),\n",
"(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),\n",
"(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),\n",
"(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),\n",
"(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),\n",
"(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),\n",
"(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),\n",
"(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),\n",
"(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),\n",
"(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),\n",
"(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),\n",
"(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),\n",
"(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),\n",
"(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),\n",
"(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),\n",
"(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),\n",
"(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),\n",
"(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),\n",
"(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');\n",
"\n",
"SELECT * FROM iris_data ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split the data into a training set (80%) and a test/validation set (20%):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(120L,)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_train, iris_test;\n",
"\n",
"-- Set seed so results are reproducible\n",
"SELECT setseed(0);\n",
"\n",
"SELECT madlib.train_test_split('iris_data', -- Source table\n",
" 'iris', -- Output table root name\n",
" 0.8, -- Train proportion\n",
" NULL, -- Test proportion (0.2)\n",
" NULL, -- Strata definition\n",
" NULL, -- Output all columns\n",
" NULL, -- Sample without replacement\n",
" TRUE -- Separate output tables\n",
" );\n",
"\n",
"SELECT COUNT(*) FROM iris_train;"
]
},
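{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal plain-Python sketch of the same idea, assuming a simple shuffle-and-slice strategy. The row ids are shuffled with a fixed seed and the first 80% are sliced off for training, which is sampling without replacement. The helper `train_test_split_ids` is hypothetical. `madlib.train_test_split` samples inside the database, so it will pick different rows, but the split sizes agree with the 120 training rows counted above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def train_test_split_ids(ids, train_proportion, seed=0):\n",
"    # Hypothetical helper: shuffle with a fixed seed so the split is reproducible\n",
"    rng = random.Random(seed)\n",
"    shuffled = list(ids)\n",
"    rng.shuffle(shuffled)\n",
"    # Sampling without replacement: every id lands in exactly one output list\n",
"    n_train = int(round(len(shuffled) * train_proportion))\n",
"    return sorted(shuffled[:n_train]), sorted(shuffled[n_train:])\n",
"\n",
"train_ids, test_ids = train_test_split_ids(range(1, 151), 0.8)\n",
"print('%d train rows, %d test rows' % (len(train_ids), len(test_ids)))"
]
},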
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pp\"></a>\n",
"# 2. Call preprocessor for deep learning\n",
"Training dataset (uses the training preprocessor, which one-hot encodes `class_text` and packs the rows into buffers, here 2 buffers of 60 rows):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>attributes_shape</th>\n",
" <th>class_text_shape</th>\n",
" <th>buffer_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[60, 4]</td>\n",
" <td>[60, 3]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[60, 4]</td>\n",
" <td>[60, 3]</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([60, 4], [60, 3], 0), ([60, 4], [60, 3], 1)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;\n",
"\n",
"SELECT madlib.training_preprocessor_dl('iris_train', -- Source table\n",
" 'iris_train_packed', -- Output table\n",
" 'class_text', -- Dependent variable\n",
" 'attributes' -- Independent variable\n",
" ); \n",
"\n",
"SELECT attributes_shape, class_text_shape, buffer_id FROM iris_train_packed ORDER BY buffer_id;"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_text_class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train</td>\n",
" <td>iris_train_packed</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>60</td>\n",
" <td>1.0</td>\n",
" <td>[3]</td>\n",
" <td>all_segments</td>\n",
" <td>all_segments</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train', u'iris_train_packed', [u'class_text'], [u'attributes'], [u'character varying'], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], 60, 1.0, [3], 'all_segments', 'all_segments')]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_train_packed_summary;"
]
},
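{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough NumPy sketch of what the packing amounts to, to make the shapes above concrete. The dependent variable is one-hot encoded against the sorted class values (the order shown in `class_text_class_values`). The independent variable is divided by `normalizing_const`. Consecutive rows are then grouped into buffers of `buffer_size` rows. The function `pack_for_dl` and the toy data are hypothetical. This is a simplified illustration, not MADlib's implementation, which also distributes the buffers across segments and can choose the buffer size itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pack_for_dl(features, labels, buffer_size, normalizing_const=1.0):\n",
"    # Hypothetical helper; class values sorted, as in class_text_class_values above\n",
"    class_values = sorted(set(labels))\n",
"    index = dict((c, i) for i, c in enumerate(class_values))\n",
"    # One-hot encode the dependent variable\n",
"    y = np.zeros((len(labels), len(class_values)), dtype=np.float32)\n",
"    for row, label in enumerate(labels):\n",
"        y[row, index[label]] = 1.0\n",
"    # Scale the independent variable by the normalizing constant\n",
"    x = np.asarray(features, dtype=np.float32) / normalizing_const\n",
"    # Group consecutive rows into buffers of at most buffer_size rows\n",
"    buffers = []\n",
"    for buffer_id, start in enumerate(range(0, len(x), buffer_size)):\n",
"        stop = start + buffer_size\n",
"        buffers.append((x[start:stop], y[start:stop], buffer_id))\n",
"    return class_values, buffers\n",
"\n",
"# Toy stand-in for the 120 training rows (not the real iris_train contents)\n",
"labels = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] * 40\n",
"features = [[5.1, 3.5, 1.4, 0.2]] * 120\n",
"class_values, buffers = pack_for_dl(features, labels, buffer_size=60)\n",
"for x, y, buffer_id in buffers:\n",
"    print('%s %s %d' % (list(x.shape), list(y.shape), buffer_id))"
]
},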
{
"cell_type": "markdown",
"metadata": {},
"source": [
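"Validation dataset (uses the validation preprocessor, which reuses the class values and normalizing constant from `iris_train_packed` so that both tables are encoded the same way):"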
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>attributes_shape</th>\n",
" <th>class_text_shape</th>\n",
" <th>buffer_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[15, 4]</td>\n",
" <td>[15, 3]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[15, 4]</td>\n",
" <td>[15, 3]</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([15, 4], [15, 3], 0), ([15, 4], [15, 3], 1)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;\n",
"\n",
"SELECT madlib.validation_preprocessor_dl('iris_test', -- Source table\n",
" 'iris_test_packed', -- Output table\n",
" 'class_text', -- Dependent variable\n",
" 'attributes', -- Independent variable\n",
" 'iris_train_packed' -- From training preprocessor step\n",
" ); \n",
"\n",
"SELECT attributes_shape, class_text_shape, buffer_id FROM iris_test_packed ORDER BY buffer_id;"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_text_class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_test</td>\n",
" <td>iris_test_packed</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>15</td>\n",
" <td>1.0</td>\n",
" <td>[3]</td>\n",
" <td>all_segments</td>\n",
" <td>all_segments</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_test', u'iris_test_packed', [u'class_text'], [u'attributes'], [u'character varying'], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], 15, 1.0, [3], 'all_segments', 'all_segments')]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_test_packed_summary;"
]
},
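{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below shows why encoding the test labels against the training class values matters, assuming a plain one-hot scheme. A batch that happens to contain no `Iris-versicolor` rows still gets 3 columns in the training order. If it were encoded against its own classes, it would get only 2 columns and would no longer line up with the model's 3-unit output layer. `one_hot` is a hypothetical helper, not a MADlib function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def one_hot(labels, class_values):\n",
"    # Hypothetical helper: encode labels against a fixed list of class values\n",
"    index = dict((c, i) for i, c in enumerate(class_values))\n",
"    y = np.zeros((len(labels), len(class_values)), dtype=np.float32)\n",
"    for row, label in enumerate(labels):\n",
"        y[row, index[label]] = 1.0\n",
"    return y\n",
"\n",
"# Class values as listed in the training summary (class_text_class_values)\n",
"train_class_values = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']\n",
"batch = ['Iris-setosa', 'Iris-virginica', 'Iris-setosa']\n",
"\n",
"# Encoded against the training class values: still 3 columns\n",
"y_test = one_hot(batch, train_class_values)\n",
"print(y_test)\n",
"\n",
"# Encoded against the batch's own classes: only 2 columns\n",
"y_wrong = one_hot(batch, sorted(set(batch)))\n",
"print(y_wrong.shape)"
]
},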
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load\"></a>\n",
"# 3. Define and load model architecture\n",
"Import Keras libraries"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow import keras\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the first model architecture: two `Dense` layers of 10 ReLU units followed by a 3-unit softmax output layer:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /Users/fmcquillan/Library/Python/2.7/lib/python/site-packages/tensorflow/python/ops/init_ops.py:1251: calling __init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Call initializer instance with the dtype argument instead of passing it to the constructor\n",
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense (Dense) (None, 10) 50 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 3) 33 \n",
"=================================================================\n",
"Total params: 193\n",
"Trainable params: 193\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model1 = Sequential()\n",
"model1.add(Dense(10, activation='relu', input_shape=(4,)))\n",
"model1.add(Dense(10, activation='relu'))\n",
"model1.add(Dense(3, activation='softmax'))\n",
" \n",
"model1.summary()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.2.4-tf\", \"config\": {\"layers\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"name\": \"sequential\"}, \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model1.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the second model architecture: three `Dense` layers of 10 ReLU units followed by a 3-unit softmax output layer:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_3 (Dense) (None, 10) 50 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_6 (Dense) (None, 3) 33 \n",
"=================================================================\n",
"Total params: 303\n",
"Trainable params: 303\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model2 = Sequential()\n",
"model2.add(Dense(10, activation='relu', input_shape=(4,)))\n",
"model2.add(Dense(10, activation='relu'))\n",
"model2.add(Dense(10, activation='relu'))\n",
"model2.add(Dense(3, activation='softmax'))\n",
" \n",
"model2.summary()"
]
},
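{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Param #` column in both summaries above can be checked by hand. A `Dense` layer with `n_in` inputs and `n_out` units has an `n_in * n_out` weight matrix plus `n_out` biases. Here is a short sketch of that arithmetic. The helper names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dense_params(n_in, n_out):\n",
"    # Weight matrix (n_in * n_out) plus one bias per unit\n",
"    return n_in * n_out + n_out\n",
"\n",
"def mlp_params(layer_sizes):\n",
"    # layer_sizes = [number of input features, units in each Dense layer...]\n",
"    return sum(dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))\n",
"\n",
"print([dense_params(4, 10), dense_params(10, 10), dense_params(10, 3)])  # [50, 110, 33]\n",
"print(mlp_params([4, 10, 10, 3]))      # 193 (model1)\n",
"print(mlp_params([4, 10, 10, 10, 3]))  # 303 (model2)"
]
},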
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.2.4-tf\", \"config\": {\"layers\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_5\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_6\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": 
\"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"name\": \"sequential_1\"}, \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model2.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load both architectures into a model architecture table:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>model_arch</th>\n",
" <th>model_weights</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>__internal_madlib_id__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_1', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>None</td>\n",
" <td>Sophie</td>\n",
" <td>MLP with 1 hidden layer</td>\n",
" <td>__madlib_temp_4017958_1614991901_4240024__</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_4', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_5', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_6', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_7', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>None</td>\n",
" <td>Maria</td>\n",
" <td>MLP with 2 hidden layers</td>\n",
" <td>__madlib_temp_28416680_1614991901_72274844__</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u' ... (1340 characters truncated) ... s_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, None, u'Sophie', u'MLP with 1 hidden layer', u'__madlib_temp_4017958_1614991901_4240024__'),\n",
" (2, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u' ... (1835 characters truncated) ... s_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, None, u'Maria', u'MLP with 2 hidden layers', u'__madlib_temp_28416680_1614991901_72274844__')]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS model_arch_library;\n",
"\n",
"SELECT madlib.load_keras_model('model_arch_library', -- Output table\n",
" \n",
"$$\n",
"{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}\n",
"$$\n",
"::json, -- JSON blob\n",
" NULL, -- Weights\n",
" 'Sophie', -- Name\n",
" 'MLP with 1 hidden layer' -- Descr\n",
");\n",
"\n",
"SELECT madlib.load_keras_model('model_arch_library', -- Output table\n",
" \n",
"$$\n",
"{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_5\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_6\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_7\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}\n",
"$$\n",
"::json, -- JSON blob\n",
" NULL, -- Weights\n",
" 'Maria', -- Name\n",
" 'MLP with 2 hidden layers' -- Descr\n",
");\n",
"\n",
"SELECT * FROM model_arch_library ORDER BY model_id;"
]
},
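{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check on the two architectures just loaded, the sketch below counts their trainable parameters by hand. It relies only on what the JSON above states: 4 input features (`batch_input_shape` of `[null, 4]`), ReLU Dense layers of 10 units (two for model 1, three for model 2), and a 3-unit softmax output, each layer with a bias. The helper `dense_param_count` is hypothetical. At 4 bytes per float32 weight, the totals are consistent with the `model_size` values reported after training (0.75390625 and 1.18359375), which suggests that column is measured in kilobytes.\n",
"\n",
"```python\n",
"# Hypothetical helper: count kernel weights plus biases in a stack of Dense layers.\n",
"def dense_param_count(input_dim, units):\n",
"    total = 0\n",
"    fan_in = input_dim\n",
"    for n in units:\n",
"        total += fan_in * n + n  # kernel is fan_in x n, bias has n entries\n",
"        fan_in = n\n",
"    return total\n",
"\n",
"\n",
"# Layer widths taken from the model_arch JSON loaded above.\n",
"sophie = dense_param_count(4, [10, 10, 3])     # model_id 1\n",
"maria = dense_param_count(4, [10, 10, 10, 3])  # model_id 2\n",
"\n",
"# float32 weights take 4 bytes each; divide by 1024 for kilobytes.\n",
"print(sophie, sophie * 4 / 1024)  # 193 0.75390625\n",
"print(maria, maria * 4 / 1024)    # 303 1.18359375\n",
"```"
]
},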
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"def_mst\"></a>\n",
"# 4. Define and load model selection tuples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
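"Generate model configurations using grid search. The output table contains one row for each unique combination of model architecture, compile parameters, and fit parameters. With 2 architectures, 3 learning rates, 2 batch sizes, and 1 epoch setting, that is 2 × 3 × 2 × 1 = 12 configurations."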
]
},
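{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, grid search is the Cartesian product of the model IDs with every list in the compile and fit parameter grids. The Python sketch below illustrates that enumeration for the grids used in the next cell. It is not MADlib's implementation; the variable names are hypothetical, and the `mst_key` numbering simply assumes the ordering shown in the next cell's output (model, then learning rate, then batch size).\n",
"\n",
"```python\n",
"import itertools\n",
"\n",
"# Grid values copied from the generate_model_configs() call in the next cell.\n",
"model_ids = [1, 2]\n",
"learning_rates = [0.001, 0.01, 0.1]\n",
"batch_sizes = [4, 8]\n",
"epochs = [1]\n",
"\n",
"# itertools.product varies the last list fastest, which here matches the\n",
"# mst_key order shown in the output of the next cell.\n",
"configs = [\n",
"    (mst_key, model_id, lr, batch_size, n_epochs)\n",
"    for mst_key, (model_id, lr, batch_size, n_epochs) in enumerate(\n",
"        itertools.product(model_ids, learning_rates, batch_sizes, epochs), start=1)\n",
"]\n",
"\n",
"print(len(configs))  # 12\n",
"for row in configs:\n",
"    print(row)  # (mst_key, model_id, lr, batch_size, epochs)\n",
"```"
]
},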
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4'),\n",
" (2, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8'),\n",
" (3, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4'),\n",
" (4, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8'),\n",
" (5, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4'),\n",
" (6, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8'),\n",
" (7, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4'),\n",
" (8, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8'),\n",
" (9, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4'),\n",
" (10, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8'),\n",
" (11, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4'),\n",
" (12, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8')]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS mst_table, mst_table_summary;\n",
"\n",
"SELECT madlib.generate_model_configs(\n",
" 'model_arch_library', -- model architecture table\n",
" 'mst_table', -- model selection table output\n",
" ARRAY[1,2], -- model ids from model architecture table\n",
" $$\n",
" {'loss': ['categorical_crossentropy'],\n",
" 'optimizer_params_list': [ {'optimizer': ['Adam'], 'lr': [0.001, 0.01, 0.1]} ],\n",
" 'metrics': ['accuracy']}\n",
" $$, -- compile_param_grid\n",
" $$\n",
" { 'batch_size': [4, 8],\n",
" 'epochs': [1]\n",
" }\n",
" $$, -- fit_param_grid\n",
" 'grid' -- search_type\n",
" );\n",
"\n",
"SELECT * FROM mst_table ORDER BY mst_key;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary table records the name of the model architecture table that the model selection table refers to:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_arch_table</th>\n",
" <th>object_table</th>\n",
" </tr>\n",
" <tr>\n",
" <td>model_arch_library</td>\n",
" <td>None</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'model_arch_library', None)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM mst_table_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"train\"></a>\n",
"# 5. Train\n",
"Train all 12 model configurations from the model selection table in a single call:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit_multiple_model</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;\n",
"\n",
"SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', -- source_table\n",
" 'iris_multi_model', -- model_output_table\n",
" 'mst_table', -- model_selection_table\n",
" 10, -- num_iterations\n",
" FALSE -- use gpus\n",
" );"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the model summary:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>validation_table</th>\n",
" <th>model</th>\n",
" <th>model_info</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_selection_table</th>\n",
" <th>object_table</th>\n",
" <th>num_iterations</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>warm_start</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_text_class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>None</td>\n",
" <td>iris_multi_model</td>\n",
" <td>iris_multi_model_info</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>mst_table</td>\n",
" <td>None</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" <td>False</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>2021-03-06 00:51:48.452654</td>\n",
" <td>2021-03-06 00:53:20.221035</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[1]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[10]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', None, u'iris_multi_model', u'iris_multi_model_info', [u'class_text'], [u'attributes'], u'model_arch_library', u'mst_table', None, 10, 10, False, None, None, datetime.datetime(2021, 3, 6, 0, 51, 48, 452654), datetime.datetime(2021, 3, 6, 0, 53, 20, 221035), u'1.18.0-dev', [1], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], [u'character varying'], 1.0, [10])]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_summary;"
]
},
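{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary shows `metrics_compute_frequency` = 10 and `metrics_iters` = [10]. The MADlib documentation describes `metrics_compute_frequency` as computing metrics every N iterations and also after the last iteration. With 10 iterations and a frequency of 10, metrics are computed once, at the end of training, which is consistent with the single-element metric arrays in the next table. The sketch below models that rule; `metrics_iterations` is a hypothetical helper, not part of MADlib, and treating an unset frequency as last iteration only is an assumption.\n",
"\n",
"```python\n",
"# Hypothetical model of the metrics schedule: every `freq` iterations,\n",
"# plus the final iteration if it is not already included.\n",
"def metrics_iterations(num_iterations, freq=None):\n",
"    if freq is None:\n",
"        freq = num_iterations  # assumed: last iteration only when unset\n",
"    iters = list(range(freq, num_iterations + 1, freq))\n",
"    if not iters or iters[-1] != num_iterations:\n",
"        iters.append(num_iterations)\n",
"    return iters\n",
"\n",
"\n",
"print(metrics_iterations(10, 10))  # [10], matching metrics_iters above\n",
"print(metrics_iterations(10, 3))   # [3, 6, 9, 10]\n",
"```"
]
},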
{
"cell_type": "markdown",
"metadata": {},
"source": [
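"View the results for each model, sorted by final training accuracy (highest first) and then by final training loss (lowest first):"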
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[90.2427790164948]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.983333349228</td>\n",
" <td>0.201789721847</td>\n",
" <td>[0.983333349227905]</td>\n",
" <td>[0.201789721846581]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[88.9964590072632]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.134730249643</td>\n",
" <td>[0.933333337306976]</td>\n",
" <td>[0.134730249643326]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[88.7690601348877]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.402144879103</td>\n",
" <td>[0.933333337306976]</td>\n",
" <td>[0.402144879102707]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[90.9196391105652]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.416792035103</td>\n",
" <td>[0.933333337306976]</td>\n",
" <td>[0.416792035102844]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[89.534707069397]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.908333361149</td>\n",
" <td>0.19042557478</td>\n",
" <td>[0.908333361148834]</td>\n",
" <td>[0.19042557477951]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[89.273796081543]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.181902274489</td>\n",
" <td>[0.899999976158142]</td>\n",
" <td>[0.181902274489403]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[90.4800100326538]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.824999988079</td>\n",
" <td>0.303107827902</td>\n",
" <td>[0.824999988079071]</td>\n",
" <td>[0.30310782790184]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[89.7936120033264]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.808333337307</td>\n",
" <td>0.300039559603</td>\n",
" <td>[0.808333337306976]</td>\n",
" <td>[0.300039559602737]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[90.0158791542053]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.658333361149</td>\n",
" <td>0.869387447834</td>\n",
" <td>[0.658333361148834]</td>\n",
" <td>[0.869387447834015]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[91.1929490566254]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.558333337307</td>\n",
" <td>0.84612262249</td>\n",
" <td>[0.558333337306976]</td>\n",
" <td>[0.846122622489929]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[91.7660541534424]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.341666668653</td>\n",
" <td>1.10138702393</td>\n",
" <td>[0.341666668653488]</td>\n",
" <td>[1.10138702392578]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[91.5026919841766]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.341666668653</td>\n",
" <td>1.10163521767</td>\n",
" <td>[0.341666668653488]</td>\n",
" <td>[1.10163521766663]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(6, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [90.2427790164948], [u'accuracy'], u'categorical_crossentropy', 0.983333349227905, 0.201789721846581, [0.983333349227905], [0.201789721846581], None, None, None, None),\n",
" (3, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [88.9964590072632], [u'accuracy'], u'categorical_crossentropy', 0.933333337306976, 0.134730249643326, [0.933333337306976], [0.134730249643326], None, None, None, None),\n",
" (7, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [88.7690601348877], [u'accuracy'], u'categorical_crossentropy', 0.933333337306976, 0.402144879102707, [0.933333337306976], [0.402144879102707], None, None, None, None),\n",
" (1, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [90.9196391105652], [u'accuracy'], u'categorical_crossentropy', 0.933333337306976, 0.416792035102844, [0.933333337306976], [0.416792035102844], None, None, None, None),\n",
" (10, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [89.534707069397], [u'accuracy'], u'categorical_crossentropy', 0.908333361148834, 0.19042557477951, [0.908333361148834], [0.19042557477951], None, None, None, None),\n",
" (9, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [89.273796081543], [u'accuracy'], u'categorical_crossentropy', 0.899999976158142, 0.181902274489403, [0.899999976158142], [0.181902274489403], None, None, None, None),\n",
" (4, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [90.4800100326538], [u'accuracy'], u'categorical_crossentropy', 0.824999988079071, 0.30310782790184, [0.824999988079071], [0.30310782790184], None, None, None, None),\n",
" (12, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [89.7936120033264], [u'accuracy'], u'categorical_crossentropy', 0.808333337306976, 0.300039559602737, [0.808333337306976], [0.300039559602737], None, None, None, None),\n",
" (2, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [90.0158791542053], [u'accuracy'], u'categorical_crossentropy', 0.658333361148834, 0.869387447834015, [0.658333361148834], [0.869387447834015], None, None, None, None),\n",
" (8, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [91.1929490566254], [u'accuracy'], u'categorical_crossentropy', 0.558333337306976, 0.846122622489929, [0.558333337306976], [0.846122622489929], None, None, None, None),\n",
" (11, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [91.7660541534424], [u'accuracy'], u'categorical_crossentropy', 0.341666668653488, 1.10138702392578, [0.341666668653488], [1.10138702392578], None, None, None, None),\n",
" (5, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [91.5026919841766], [u'accuracy'], u'categorical_crossentropy', 0.341666668653488, 1.10163521766663, [0.341666668653488], [1.10163521766663], None, None, None, None)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_info ORDER BY training_metrics_final DESC, training_loss_final;"
]
},
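{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ORDER BY` clause above ranks configurations by final training accuracy, breaking ties by final training loss. The Python sketch below reproduces that ranking client-side from values copied out of the table (only `mst_key`, `training_metrics_final`, and `training_loss_final`); the variable names are illustrative. Keep in mind that these are training metrics: selecting on them alone can favor models that overfit, so a validation table, introduced later in this notebook, is usually a better basis for picking a configuration.\n",
"\n",
"```python\n",
"# (mst_key, training_metrics_final, training_loss_final) copied from the output above.\n",
"results = [\n",
"    (1, 0.933333337306976, 0.416792035102844),\n",
"    (2, 0.658333361148834, 0.869387447834015),\n",
"    (3, 0.933333337306976, 0.134730249643326),\n",
"    (4, 0.824999988079071, 0.30310782790184),\n",
"    (5, 0.341666668653488, 1.10163521766663),\n",
"    (6, 0.983333349227905, 0.201789721846581),\n",
"    (7, 0.933333337306976, 0.402144879102707),\n",
"    (8, 0.558333337306976, 0.846122622489929),\n",
"    (9, 0.899999976158142, 0.181902274489403),\n",
"    (10, 0.908333361148834, 0.19042557477951),\n",
"    (11, 0.341666668653488, 1.10138702392578),\n",
"    (12, 0.808333337306976, 0.300039559602737),\n",
"]\n",
"\n",
"# Same ordering as ORDER BY training_metrics_final DESC, training_loss_final.\n",
"ranked = sorted(results, key=lambda r: (-r[1], r[2]))\n",
"print([r[0] for r in ranked])  # [6, 3, 7, 1, 10, 9, 4, 12, 2, 8, 11, 5]\n",
"print('best on training accuracy: mst_key', ranked[0][0])  # mst_key 6\n",
"```"
]
},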
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"eval\"></a>\n",
"# 6. Evaluate\n",
"\n",
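"Now evaluate one of the trained models on the test set. Because the model output table contains all 12 models, pass the `mst_key` of the one to evaluate (here, 9):"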
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>loss</th>\n",
" <th>metric</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.194916069508</td>\n",
" <td>0.899999976158</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.194916069507599, 0.899999976158142, [u'accuracy'], u'categorical_crossentropy')]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_validate;\n",
"SELECT madlib.madlib_keras_evaluate('iris_multi_model', -- model\n",
" 'iris_test_packed', -- test table\n",
" 'iris_validate', -- output table\n",
" NULL, -- use gpus\n",
" 9 -- mst_key to use\n",
" );\n",
"\n",
"SELECT * FROM iris_validate;"
]
},
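{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `loss` and `metric` columns above are the categorical cross-entropy and accuracy of the model with `mst_key` 9 on the test set. As a reminder of what those two numbers measure, the sketch below computes both for a tiny, made-up batch of one-hot labels and softmax outputs, following the standard definitions (mean over samples of minus the log-probability assigned to the true class; fraction of samples whose most probable class is the true class). The data and function names are hypothetical, and unlike Keras this sketch does not clip probabilities away from zero.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def argmax(v):\n",
"    # Index of the largest entry.\n",
"    return max(range(len(v)), key=v.__getitem__)\n",
"\n",
"\n",
"def categorical_crossentropy(y_true, y_pred):\n",
"    # Mean over samples of -sum_k y_true[k] * log(y_pred[k]).\n",
"    losses = [-sum(t * math.log(p) for t, p in zip(ts, ps))\n",
"              for ts, ps in zip(y_true, y_pred)]\n",
"    return sum(losses) / len(losses)\n",
"\n",
"\n",
"def accuracy(y_true, y_pred):\n",
"    # Fraction of samples whose predicted class matches the true class.\n",
"    hits = sum(argmax(t) == argmax(p) for t, p in zip(y_true, y_pred))\n",
"    return hits / len(y_true)\n",
"\n",
"\n",
"# Made-up one-hot labels and softmax outputs for three samples.\n",
"y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]\n",
"y_pred = [[0.8, 0.1, 0.1], [0.3, 0.6, 0.1], [0.2, 0.5, 0.3]]\n",
"\n",
"print(categorical_crossentropy(y_true, y_pred))\n",
"print(accuracy(y_true, y_pred))  # 0.666... (2 of 3 correct)\n",
"```"
]
},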
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred\"></a>\n",
"# 7. Predict\n",
"\n",
"Now predict using one of the trained models. For convenience we predict on the same test data used for evaluation, which is not typical practice but serves to show the syntax. The predicted class is in the `class_value` column and its probability in the `prob` column:"
]
},
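{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the output below, each row reports a single predicted class together with a probability. A reasonable mental model, consistent with the softmax output layer in the architectures loaded earlier, is that `class_value` is the class with the highest softmax probability and `prob` is that probability. The sketch below illustrates that assumption with made-up logits. The class order is taken from `class_text_class_values` in the model summary; `softmax` and `predict_response` are hypothetical helpers, not MADlib functions.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Class order as listed in class_text_class_values in the model summary.\n",
"CLASSES = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']\n",
"\n",
"\n",
"def softmax(logits):\n",
"    # Subtract the max logit for numerical stability.\n",
"    m = max(logits)\n",
"    exps = [math.exp(z - m) for z in logits]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"def predict_response(logits):\n",
"    # Return (class label, probability) for the most probable class.\n",
"    probs = softmax(logits)\n",
"    best = max(range(len(probs)), key=probs.__getitem__)\n",
"    return CLASSES[best], probs[best]\n",
"\n",
"\n",
"label, prob = predict_response([0.5, 2.5, 1.0])  # made-up logits\n",
"print(label, round(prob, 4))  # label is Iris-versicolor\n",
"```"
]
},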
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"30 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>class_name</th>\n",
" <th>class_value</th>\n",
" <th>prob</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999999</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999999</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999999</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999999</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.99069124</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9864196</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9983382</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9991603</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9974559</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.60661113</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9940832</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9987955</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.7598468</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.8414144</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.715776</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.9163472</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.5081183</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.85080105</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.9842195</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6804195</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.81555897</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.92707217</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7158722</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.55272627</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7662018</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3, u'class_text', u'Iris-setosa', 1.0),\n",
" (10, u'class_text', u'Iris-setosa', 0.9999999),\n",
" (12, u'class_text', u'Iris-setosa', 1.0),\n",
" (14, u'class_text', u'Iris-setosa', 0.9999999),\n",
" (18, u'class_text', u'Iris-setosa', 1.0),\n",
" (20, u'class_text', u'Iris-setosa', 1.0),\n",
" (30, u'class_text', u'Iris-setosa', 0.9999999),\n",
" (31, u'class_text', u'Iris-setosa', 0.9999999),\n",
" (49, u'class_text', u'Iris-setosa', 1.0),\n",
" (55, u'class_text', u'Iris-versicolor', 0.99069124),\n",
" (64, u'class_text', u'Iris-versicolor', 0.9864196),\n",
" (70, u'class_text', u'Iris-versicolor', 0.9983382),\n",
" (76, u'class_text', u'Iris-versicolor', 0.9991603),\n",
" (82, u'class_text', u'Iris-versicolor', 0.9974559),\n",
" (84, u'class_text', u'Iris-versicolor', 0.60661113),\n",
" (92, u'class_text', u'Iris-versicolor', 0.9940832),\n",
" (98, u'class_text', u'Iris-versicolor', 0.9987955),\n",
" (99, u'class_text', u'Iris-versicolor', 0.7598468),\n",
" (102, u'class_text', u'Iris-virginica', 0.8414144),\n",
" (107, u'class_text', u'Iris-virginica', 0.715776),\n",
" (114, u'class_text', u'Iris-virginica', 0.9163472),\n",
" (117, u'class_text', u'Iris-versicolor', 0.5081183),\n",
" (121, u'class_text', u'Iris-virginica', 0.85080105),\n",
" (123, u'class_text', u'Iris-virginica', 0.9842195),\n",
" (125, u'class_text', u'Iris-virginica', 0.6804195),\n",
" (127, u'class_text', u'Iris-versicolor', 0.81555897),\n",
" (145, u'class_text', u'Iris-virginica', 0.92707217),\n",
" (147, u'class_text', u'Iris-virginica', 0.7158722),\n",
" (148, u'class_text', u'Iris-versicolor', 0.55272627),\n",
" (149, u'class_text', u'Iris-virginica', 0.7662018)]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict;\n",
"\n",
"SELECT madlib.madlib_keras_predict('iris_multi_model', -- model\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict', -- output table\n",
" 'response', -- prediction type\n",
" FALSE, -- use gpus\n",
" 9 -- mst_key to use\n",
" );\n",
"\n",
"SELECT * FROM iris_predict ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count misclassifications:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3L,)]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM iris_predict JOIN iris_test USING (id) \n",
"WHERE iris_predict.class_value != iris_test.class_text;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the percentage of test rows classified correctly (test accuracy):"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>test_accuracy_percent</th>\n",
" </tr>\n",
" <tr>\n",
" <td>90.00</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('90.00'),)]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT round(100.0 * count(*) / (SELECT count(*) FROM iris_test), 2) AS test_accuracy_percent\n",
"FROM iris_predict JOIN iris_test USING (id)\n",
"WHERE iris_predict.class_value = iris_test.class_text;"
]
},
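{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy figure is simple arithmetic on the two counts. Of the 30 test rows, 3 are misclassified, so accuracy is 100 * (30 - 3) / 30 = 90%. The Python sketch below performs the same calculation on a list of (actual, predicted) label pairs. It is illustrative only. `misclassified_count` and `accuracy_percent` are hypothetical helpers. The label pairs are synthetic stand-ins chosen to match those counts, not rows read from `iris_test` or `iris_predict`.\n",
"\n",
"```python\n",
"def misclassified_count(pairs):\n",
"    # Number of (actual, predicted) pairs whose labels differ.\n",
"    return sum(1 for actual, predicted in pairs if actual != predicted)\n",
"\n",
"\n",
"def accuracy_percent(pairs):\n",
"    # Percentage of pairs predicted correctly, rounded to 2 decimals.\n",
"    if not pairs:\n",
"        raise ValueError('no rows to score')\n",
"    correct = len(pairs) - misclassified_count(pairs)\n",
"    return round(100.0 * correct / len(pairs), 2)\n",
"\n",
"\n",
"# Synthetic pairs: 27 correct predictions and 3 errors.\n",
"pairs = ([('Iris-setosa', 'Iris-setosa')] * 27\n",
"         + [('Iris-virginica', 'Iris-versicolor')] * 3)\n",
"print(misclassified_count(pairs))  # 3\n",
"print(accuracy_percent(pairs))     # 90.0\n",
"```\n",
"\n",
"Dividing by the number of scored rows, rather than a fixed test-set size, keeps the percentage correct if the train/test split changes."
]
},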
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"class2\"></a>\n",
"# Classification with Other Parameters\n",
"\n",
"<a id=\"val_dataset\"></a>\n",
"# 1. Validation dataset\n",
"\n",
"Now train again using a validation dataset, and compute metrics every 3rd iteration with the 'metrics_compute_frequency' parameter. Metrics are also computed on the final iteration. Computing metrics less often can reduce run time if you do not need them at every iteration."
]
},
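{
"cell_type": "markdown",
"metadata": {},
"source": [
"The call below uses `num_iterations` = 10 and `metrics_compute_frequency` = 3. With those settings, the model summary later in this section reports `metrics_iters` = [3, 6, 9, 10]. Metrics are computed at every multiple of the frequency and again at the final iteration. The Python sketch below reproduces that schedule. It is inferred from this output rather than taken from MADlib's source, and `metrics_iterations` is a hypothetical helper name.\n",
"\n",
"```python\n",
"def metrics_iterations(num_iterations, metrics_compute_frequency):\n",
"    # Every multiple of the frequency up to num_iterations.\n",
"    iters = list(range(metrics_compute_frequency, num_iterations + 1,\n",
"                       metrics_compute_frequency))\n",
"    # Plus the final iteration, if it is not already included.\n",
"    if not iters or iters[-1] != num_iterations:\n",
"        iters.append(num_iterations)\n",
"    return iters\n",
"\n",
"\n",
"print(metrics_iterations(10, 3))  # [3, 6, 9, 10]\n",
"print(metrics_iterations(10, 2))  # [2, 4, 6, 8, 10]\n",
"```"
]
},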
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit_multiple_model</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;\n",
"\n",
"SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', -- source_table\n",
" 'iris_multi_model', -- model_output_table\n",
" 'mst_table', -- model_selection_table\n",
" 10, -- num_iterations\n",
" FALSE, -- use gpus\n",
" 'iris_test_packed', -- validation dataset\n",
" 3, -- metrics compute frequency\n",
" FALSE, -- warm start\n",
" 'Sophie L.', -- name\n",
" 'Model selection for iris dataset' -- description\n",
" );"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the model summary:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>validation_table</th>\n",
" <th>model</th>\n",
" <th>model_info</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_selection_table</th>\n",
" <th>object_table</th>\n",
" <th>num_iterations</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>warm_start</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_text_class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_test_packed</td>\n",
" <td>iris_multi_model</td>\n",
" <td>iris_multi_model_info</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>mst_table</td>\n",
" <td>None</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>Sophie L.</td>\n",
" <td>Model selection for iris dataset</td>\n",
" <td>2021-03-06 00:53:31.218406</td>\n",
" <td>2021-03-06 00:55:25.621208</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[1]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[3, 6, 9, 10]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_test_packed', u'iris_multi_model', u'iris_multi_model_info', [u'class_text'], [u'attributes'], u'model_arch_library', u'mst_table', None, 10, 3, False, u'Sophie L.', u'Model selection for iris dataset', datetime.datetime(2021, 3, 6, 0, 53, 31, 218406), datetime.datetime(2021, 3, 6, 0, 55, 25, 621208), u'1.18.0-dev', [1], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], [u'character varying'], 1.0, [3, 6, 9, 10])]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the performance of each model, sorted by final validation accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[31.5490398406982, 64.223620891571, 97.8899219036102, 113.156138896942]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.991666674614</td>\n",
" <td>0.177691921592</td>\n",
" <td>[0.824999988079071, 0.975000023841858, 0.933333337306976, 0.991666674613953]</td>\n",
" <td>[0.508709609508514, 0.290052831172943, 0.217903628945351, 0.177691921591759]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.20564225316</td>\n",
" <td>[0.833333313465118, 0.966666638851166, 0.933333337306976, 0.966666638851166]</td>\n",
" <td>[0.516587793827057, 0.316147029399872, 0.228292018175125, 0.205642253160477]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[30.4000718593597, 62.9767029285431, 96.690801858902, 112.145288944244]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.908333361149</td>\n",
" <td>0.203085869551</td>\n",
" <td>[0.933333337306976, 0.808333337306976, 0.958333313465118, 0.908333361148834]</td>\n",
" <td>[0.372362315654755, 0.304766088724136, 0.11820487678051, 0.203085869550705]</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.190864190459</td>\n",
" <td>[0.966666638851166, 0.833333313465118, 0.966666638851166, 0.933333337306976]</td>\n",
" <td>[0.347199022769928, 0.290798246860504, 0.110275268554688, 0.190864190459251]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[30.875373840332, 63.4593389034271, 97.1958589553833, 112.702126979828]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.883333325386</td>\n",
" <td>0.692279815674</td>\n",
" <td>[0.533333361148834, 0.616666674613953, 0.875, 0.883333325386047]</td>\n",
" <td>[1.08197057247162, 0.851473987102509, 0.729827761650085, 0.692279815673828]</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.674779772758</td>\n",
" <td>[0.600000023841858, 0.666666686534882, 0.899999976158142, 0.899999976158142]</td>\n",
" <td>[1.05298256874084, 0.817528009414673, 0.710631787776947, 0.674779772758484]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[29.8903229236603, 62.4677069187164, 96.1764039993286, 111.539803981781]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.925000011921</td>\n",
" <td>0.176520362496</td>\n",
" <td>[0.833333313465118, 0.925000011920929, 0.774999976158142, 0.925000011920929]</td>\n",
" <td>[0.324734181165695, 0.182637020945549, 0.468331128358841, 0.176520362496376]</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.2585529387</td>\n",
" <td>[0.866666674613953, 0.866666674613953, 0.866666674613953, 0.899999976158142]</td>\n",
" <td>[0.341204434633255, 0.261798053979874, 0.45467621088028, 0.258552938699722]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[31.7836039066315, 64.4592599868774, 98.1328208446503, 113.377946853638]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.891666650772</td>\n",
" <td>0.797108471394</td>\n",
" <td>[0.341666668653488, 0.491666674613953, 0.916666686534882, 0.891666650772095]</td>\n",
" <td>[1.09786474704742, 0.967048287391663, 0.838281869888306, 0.797108471393585]</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.800795376301</td>\n",
" <td>[0.300000011920929, 0.433333337306976, 0.933333337306976, 0.899999976158142]</td>\n",
" <td>[1.07609903812408, 0.962578594684601, 0.834975183010101, 0.800795376300812]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[30.1456639766693, 62.722916841507, 96.4333670139313, 111.892151832581]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.816666662693</td>\n",
" <td>0.734887838364</td>\n",
" <td>[0.850000023841858, 0.958333313465118, 0.966666638851166, 0.816666662693024]</td>\n",
" <td>[0.335647404193878, 0.0894104242324829, 0.0672163665294647, 0.734887838363647]</td>\n",
" <td>0.866666674614</td>\n",
" <td>0.665323019028</td>\n",
" <td>[0.866666674613953, 0.966666638851166, 0.966666638851166, 0.866666674613953]</td>\n",
" <td>[0.320426166057587, 0.154994085431099, 0.204012081027031, 0.66532301902771]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[32.0452349185944, 64.7241299152374, 98.4015560150146, 113.899842977524]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.791666686535</td>\n",
" <td>0.772948563099</td>\n",
" <td>[0.316666662693024, 0.349999994039536, 0.725000023841858, 0.791666686534882]</td>\n",
" <td>[1.01266825199127, 0.905348658561707, 0.807280421257019, 0.772948563098907]</td>\n",
" <td>0.866666674614</td>\n",
" <td>0.740880072117</td>\n",
" <td>[0.400000005960464, 0.466666668653488, 0.800000011920929, 0.866666674613953]</td>\n",
" <td>[0.964996755123138, 0.868514597415924, 0.771895349025726, 0.740880072116852]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[30.6602540016174, 63.2428169250488, 96.9531948566437, 112.484740972519]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.691666662693</td>\n",
" <td>0.501820206642</td>\n",
" <td>[0.658333361148834, 0.658333361148834, 0.658333361148834, 0.691666662693024]</td>\n",
" <td>[0.654709756374359, 0.581917643547058, 1.33844769001007, 0.501820206642151]</td>\n",
" <td>0.766666650772</td>\n",
" <td>0.457984447479</td>\n",
" <td>[0.699999988079071, 0.699999988079071, 0.699999988079071, 0.766666650772095]</td>\n",
" <td>[0.592061340808868, 0.525563180446625, 1.17788350582123, 0.457984447479248]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[31.0910878181458, 63.7646949291229, 97.4185988903046, 112.939773797989]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.666666686535</td>\n",
" <td>0.50052946806</td>\n",
" <td>[0.433333337306976, 0.641666650772095, 0.649999976158142, 0.666666686534882]</td>\n",
" <td>[0.850135624408722, 0.611121952533722, 0.509139358997345, 0.50052946805954]</td>\n",
" <td>0.733333349228</td>\n",
" <td>0.459399551153</td>\n",
" <td>[0.466666668653488, 0.699999988079071, 0.699999988079071, 0.733333349227905]</td>\n",
" <td>[0.802468597888947, 0.571285247802734, 0.492577910423279, 0.459399551153183]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[29.6670269966125, 62.2440509796143, 95.9554150104523, 111.311369895935]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.733333349228</td>\n",
" <td>0.821944594383</td>\n",
" <td>[0.341666668653488, 0.341666668653488, 0.658333361148834, 0.733333349227905]</td>\n",
" <td>[1.06431686878204, 0.996406197547913, 0.869706034660339, 0.82194459438324]</td>\n",
" <td>0.699999988079</td>\n",
" <td>0.852133929729</td>\n",
" <td>[0.300000011920929, 0.300000011920929, 0.699999988079071, 0.699999988079071]</td>\n",
" <td>[1.09268116950989, 1.01670277118683, 0.891825795173645, 0.852133929729462]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[32.5322558879852, 65.2217888832092, 98.9477097988129, 114.400418996811]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.683333337307</td>\n",
" <td>0.455871999264</td>\n",
" <td>[0.725000023841858, 0.683333337306976, 0.683333337306976, 0.683333337306976]</td>\n",
" <td>[0.383917421102524, 0.457853585481644, 0.455943495035172, 0.455871999263763]</td>\n",
" <td>0.600000023842</td>\n",
" <td>0.488439053297</td>\n",
" <td>[0.800000011920929, 0.600000023841858, 0.600000023841858, 0.600000023841858]</td>\n",
" <td>[0.388951361179352, 0.50080794095993, 0.487448841333389, 0.488439053297043]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[32.2720308303833, 64.9502189159393, 98.6836059093475, 114.134181976318]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.675000011921</td>\n",
" <td>0.452209770679</td>\n",
" <td>[0.683333337306976, 0.675000011920929, 0.683333337306976, 0.675000011920929]</td>\n",
" <td>[0.492754250764847, 0.469423890113831, 0.571796059608459, 0.452209770679474]</td>\n",
" <td>0.600000023842</td>\n",
" <td>0.464268505573</td>\n",
" <td>[0.733333349227905, 0.766666650772095, 0.600000023841858, 0.600000023841858]</td>\n",
" <td>[0.438488334417343, 0.390993624925613, 0.690678656101227, 0.464268505573273]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(4, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [31.5490398406982, 64.223620891571, 97.8899219036102, 113.156138896942], [u'accuracy'], u'categorical_crossentropy', 0.991666674613953, 0.177691921591759, [0.824999988079071, 0.975000023841858, 0.933333337306976, 0.991666674613953], [0.508709609508514, 0.290052831172943, 0.217903628945351, 0.177691921591759], 0.966666638851166, 0.205642253160477, [0.833333313465118, 0.966666638851166, 0.933333337306976, 0.966666638851166], [0.516587793827057, 0.316147029399872, 0.228292018175125, 0.205642253160477]),\n",
" (10, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [30.4000718593597, 62.9767029285431, 96.690801858902, 112.145288944244], [u'accuracy'], u'categorical_crossentropy', 0.908333361148834, 0.203085869550705, [0.933333337306976, 0.808333337306976, 0.958333313465118, 0.908333361148834], [0.372362315654755, 0.304766088724136, 0.11820487678051, 0.203085869550705], 0.933333337306976, 0.190864190459251, [0.966666638851166, 0.833333313465118, 0.966666638851166, 0.933333337306976], [0.347199022769928, 0.290798246860504, 0.110275268554688, 0.190864190459251]),\n",
" (2, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [30.875373840332, 63.4593389034271, 97.1958589553833, 112.702126979828], [u'accuracy'], u'categorical_crossentropy', 0.883333325386047, 0.692279815673828, [0.533333361148834, 0.616666674613953, 0.875, 0.883333325386047], [1.08197057247162, 0.851473987102509, 0.729827761650085, 0.692279815673828], 0.899999976158142, 0.674779772758484, [0.600000023841858, 0.666666686534882, 0.899999976158142, 0.899999976158142], [1.05298256874084, 0.817528009414673, 0.710631787776947, 0.674779772758484]),\n",
" (3, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [29.8903229236603, 62.4677069187164, 96.1764039993286, 111.539803981781], [u'accuracy'], u'categorical_crossentropy', 0.925000011920929, 0.176520362496376, [0.833333313465118, 0.925000011920929, 0.774999976158142, 0.925000011920929], [0.324734181165695, 0.182637020945549, 0.468331128358841, 0.176520362496376], 0.899999976158142, 0.258552938699722, [0.866666674613953, 0.866666674613953, 0.866666674613953, 0.899999976158142], [0.341204434633255, 0.261798053979874, 0.45467621088028, 0.258552938699722]),\n",
" (1, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [31.7836039066315, 64.4592599868774, 98.1328208446503, 113.377946853638], [u'accuracy'], u'categorical_crossentropy', 0.891666650772095, 0.797108471393585, [0.341666668653488, 0.491666674613953, 0.916666686534882, 0.891666650772095], [1.09786474704742, 0.967048287391663, 0.838281869888306, 0.797108471393585], 0.899999976158142, 0.800795376300812, [0.300000011920929, 0.433333337306976, 0.933333337306976, 0.899999976158142], [1.07609903812408, 0.962578594684601, 0.834975183010101, 0.800795376300812]),\n",
" (9, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [30.1456639766693, 62.722916841507, 96.4333670139313, 111.892151832581], [u'accuracy'], u'categorical_crossentropy', 0.816666662693024, 0.734887838363647, [0.850000023841858, 0.958333313465118, 0.966666638851166, 0.816666662693024], [0.335647404193878, 0.0894104242324829, 0.0672163665294647, 0.734887838363647], 0.866666674613953, 0.66532301902771, [0.866666674613953, 0.966666638851166, 0.966666638851166, 0.866666674613953], [0.320426166057587, 0.154994085431099, 0.204012081027031, 0.66532301902771]),\n",
" (8, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [32.0452349185944, 64.7241299152374, 98.4015560150146, 113.899842977524], [u'accuracy'], u'categorical_crossentropy', 0.791666686534882, 0.772948563098907, [0.316666662693024, 0.349999994039536, 0.725000023841858, 0.791666686534882], [1.01266825199127, 0.905348658561707, 0.807280421257019, 0.772948563098907], 0.866666674613953, 0.740880072116852, [0.400000005960464, 0.466666668653488, 0.800000011920929, 0.866666674613953], [0.964996755123138, 0.868514597415924, 0.771895349025726, 0.740880072116852]),\n",
" (12, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [30.6602540016174, 63.2428169250488, 96.9531948566437, 112.484740972519], [u'accuracy'], u'categorical_crossentropy', 0.691666662693024, 0.501820206642151, [0.658333361148834, 0.658333361148834, 0.658333361148834, 0.691666662693024], [0.654709756374359, 0.581917643547058, 1.33844769001007, 0.501820206642151], 0.766666650772095, 0.457984447479248, [0.699999988079071, 0.699999988079071, 0.699999988079071, 0.766666650772095], [0.592061340808868, 0.525563180446625, 1.17788350582123, 0.457984447479248]),\n",
" (6, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [31.0910878181458, 63.7646949291229, 97.4185988903046, 112.939773797989], [u'accuracy'], u'categorical_crossentropy', 0.666666686534882, 0.50052946805954, [0.433333337306976, 0.641666650772095, 0.649999976158142, 0.666666686534882], [0.850135624408722, 0.611121952533722, 0.509139358997345, 0.50052946805954], 0.733333349227905, 0.459399551153183, [0.466666668653488, 0.699999988079071, 0.699999988079071, 0.733333349227905], [0.802468597888947, 0.571285247802734, 0.492577910423279, 0.459399551153183]),\n",
" (7, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [29.6670269966125, 62.2440509796143, 95.9554150104523, 111.311369895935], [u'accuracy'], u'categorical_crossentropy', 0.733333349227905, 0.82194459438324, [0.341666668653488, 0.341666668653488, 0.658333361148834, 0.733333349227905], [1.06431686878204, 0.996406197547913, 0.869706034660339, 0.82194459438324], 0.699999988079071, 0.852133929729462, [0.300000011920929, 0.300000011920929, 0.699999988079071, 0.699999988079071], [1.09268116950989, 1.01670277118683, 0.891825795173645, 0.852133929729462]),\n",
" (11, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [32.5322558879852, 65.2217888832092, 98.9477097988129, 114.400418996811], [u'accuracy'], u'categorical_crossentropy', 0.683333337306976, 0.455871999263763, [0.725000023841858, 0.683333337306976, 0.683333337306976, 0.683333337306976], [0.383917421102524, 0.457853585481644, 0.455943495035172, 0.455871999263763], 0.600000023841858, 0.488439053297043, [0.800000011920929, 0.600000023841858, 0.600000023841858, 0.600000023841858], [0.388951361179352, 0.50080794095993, 0.487448841333389, 0.488439053297043]),\n",
" (5, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [32.2720308303833, 64.9502189159393, 98.6836059093475, 114.134181976318], [u'accuracy'], u'categorical_crossentropy', 0.675000011920929, 0.452209770679474, [0.683333337306976, 0.675000011920929, 0.683333337306976, 0.675000011920929], [0.492754250764847, 0.469423890113831, 0.571796059608459, 0.452209770679474], 0.600000023841858, 0.464268505573273, [0.733333349227905, 0.766666650772095, 0.600000023841858, 0.600000023841858], [0.438488334417343, 0.390993624925613, 0.690678656101227, 0.464268505573273])]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_info ORDER BY validation_metrics_final DESC;"
]
},
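{
"cell_type": "markdown",
"metadata": {},
"source": [
"To pick a configuration programmatically, the same ordering can be reproduced outside the database. The Python sketch below ranks models by final validation accuracy and breaks ties with the lower final validation loss. It then keeps the best `mst_key` for each model architecture (`model_id`). The tie-breaking rule is an assumption for illustration, not something MADlib applies. The numbers are copied from the `iris_multi_model_info` output above.\n",
"\n",
"```python\n",
"# (mst_key, model_id, validation_metrics_final, validation_loss_final)\n",
"results = [\n",
"    (4, 1, 0.966666638851166, 0.205642253160477),\n",
"    (10, 2, 0.933333337306976, 0.190864190459251),\n",
"    (2, 1, 0.899999976158142, 0.674779772758484),\n",
"    (3, 1, 0.899999976158142, 0.258552938699722),\n",
"    (1, 1, 0.899999976158142, 0.800795376300812),\n",
"    (9, 2, 0.866666674613953, 0.66532301902771),\n",
"    (8, 2, 0.866666674613953, 0.740880072116852),\n",
"    (12, 2, 0.766666650772095, 0.457984447479248),\n",
"    (6, 1, 0.733333349227905, 0.459399551153183),\n",
"    (7, 2, 0.699999988079071, 0.852133929729462),\n",
"    (11, 2, 0.600000023841858, 0.488439053297043),\n",
"    (5, 1, 0.600000023841858, 0.464268505573273),\n",
"]\n",
"\n",
"# Highest validation accuracy first; ties go to the lower validation loss.\n",
"ranked = sorted(results, key=lambda r: (-r[2], r[3]))\n",
"print([r[0] for r in ranked[:4]])  # [4, 10, 3, 2]\n",
"\n",
"# First (best) mst_key seen for each model_id.\n",
"best_per_model = {}\n",
"for mst_key, model_id, acc, loss in ranked:\n",
"    best_per_model.setdefault(model_id, mst_key)\n",
"print(best_per_model)  # {1: 4, 2: 10}\n",
"```\n",
"\n",
"With only 30 validation rows, each row is worth about 3.3 percentage points of accuracy. Read small differences between configurations with caution."
]
},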
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot validation results"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"from collections import defaultdict\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"sns.set_palette(sns.color_palette(\"hls\", 20))\n",
"plt.rcParams.update({'font.size': 12})\n",
"pd.set_option('display.max_colwidth', -1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" fig.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Forward mouse-move events to matplotlib (no throttling is applied here).\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handle_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";\n",
"\n",
"var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable'\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn itself\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" // select the cell after this one\n",
" var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
" IPython.notebook.select(index + 1);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"720\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"df_results = %sql SELECT * FROM iris_multi_model_info ORDER BY validation_loss_final ASC LIMIT 7;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"df_summary = %sql SELECT * FROM iris_multi_model_summary;\n",
"df_summary = df_summary.DataFrame()\n",
"\n",
"# Set up two side-by-side plots: validation metric (left) and validation loss (right)\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Validation metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Validation loss curve')\n",
"\n",
"# Iterations at which the metrics were computed (shared by all models)\n",
"iters = df_summary['metrics_iters'][0]\n",
"\n",
"# Draw one curve per model selection tuple (mst_key)\n",
"for mst_key in df_results['mst_key']:\n",
" df_output_info = %sql SELECT validation_metrics,validation_loss FROM iris_multi_model_info WHERE mst_key = $mst_key\n",
" df_output_info = df_output_info.DataFrame()\n",
" validation_metrics = df_output_info['validation_metrics'][0]\n",
" validation_loss = df_output_info['validation_loss'][0]\n",
" \n",
" ax_metric.plot(iters, validation_metrics, label=mst_key, marker='o')\n",
" ax_loss.plot(iters, validation_loss, label=mst_key, marker='o')\n",
"\n",
"# Add the legends and tighten the layout only after all curves are drawn\n",
"ax_metric.legend()\n",
"ax_loss.legend()\n",
"fig.tight_layout();\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred_prob\"></a>\n",
"# 2. Predict probabilities\n",
"\n",
"Predict the probability of each class, ranked from most to least likely (the class with `rank` = 1 is the predicted class):"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"90 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>class_name</th>\n",
" <th>class_value</th>\n",
" <th>prob</th>\n",
" <th>rank</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999932</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>6.7611923e-06</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.2535056e-10</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999808</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>1.9209425e-05</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>4.433645e-10</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.99998367</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>1.6334934e-05</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>4.3492965e-10</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999931</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>6.9504345e-06</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.9190094e-10</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.99999726</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>2.719827e-06</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>2.4018267e-11</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9999982</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>1.8036015e-06</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.515534e-11</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.99996376</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>3.623055e-05</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.4014193e-09</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.99995685</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>4.3105167e-05</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.541236e-09</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.99999833</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>1.6733742e-06</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.0720992e-11</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.97456545</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.025385397</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>4.912654e-05</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.8837083</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.11627731</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.4444132e-05</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9832433</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.016161945</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0005947249</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9934144</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.006202936</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.00038262276</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9880006</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.01050145</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0014980072</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.743757</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.25624287</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.1804799e-07</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9489498</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.050999135</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>5.1051586e-05</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.9882598</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.011410431</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.00032975432</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.7122672</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.2864844</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.0012483773</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.8344315</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.16556835</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>4.9313943e-08</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7617606</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.23823881</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>6.2156596e-07</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.85601324</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.1439867</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>3.4068247e-08</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.76065344</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.23934652</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>4.0775706e-08</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.65924823</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.34075174</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>3.7877243e-08</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.968423</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.031577036</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>1.5606285e-11</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.72842705</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.2715729</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>3.7875385e-08</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.8053533</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.19464317</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>3.5179064e-06</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7297866</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.2702134</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>2.8784607e-08</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5341273</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4658725</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>2.3799986e-07</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.6266347</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.3733647</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>5.7692125e-07</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5517554</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4482443</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>3.1108453e-07</td>\n",
" <td>3</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3, u'class_text', u'Iris-setosa', 0.9999932, 1),\n",
" (3, u'class_text', u'Iris-versicolor', 6.7611923e-06, 2),\n",
" (3, u'class_text', u'Iris-virginica', 1.2535056e-10, 3),\n",
" (10, u'class_text', u'Iris-setosa', 0.9999808, 1),\n",
" (10, u'class_text', u'Iris-versicolor', 1.9209425e-05, 2),\n",
" (10, u'class_text', u'Iris-virginica', 4.433645e-10, 3),\n",
" (12, u'class_text', u'Iris-setosa', 0.99998367, 1),\n",
" (12, u'class_text', u'Iris-versicolor', 1.6334934e-05, 2),\n",
" (12, u'class_text', u'Iris-virginica', 4.3492965e-10, 3),\n",
" (14, u'class_text', u'Iris-setosa', 0.9999931, 1),\n",
" (14, u'class_text', u'Iris-versicolor', 6.9504345e-06, 2),\n",
" (14, u'class_text', u'Iris-virginica', 1.9190094e-10, 3),\n",
" (18, u'class_text', u'Iris-setosa', 0.99999726, 1),\n",
" (18, u'class_text', u'Iris-versicolor', 2.719827e-06, 2),\n",
" (18, u'class_text', u'Iris-virginica', 2.4018267e-11, 3),\n",
" (20, u'class_text', u'Iris-setosa', 0.9999982, 1),\n",
" (20, u'class_text', u'Iris-versicolor', 1.8036015e-06, 2),\n",
" (20, u'class_text', u'Iris-virginica', 1.515534e-11, 3),\n",
" (30, u'class_text', u'Iris-setosa', 0.99996376, 1),\n",
" (30, u'class_text', u'Iris-versicolor', 3.623055e-05, 2),\n",
" (30, u'class_text', u'Iris-virginica', 1.4014193e-09, 3),\n",
" (31, u'class_text', u'Iris-setosa', 0.99995685, 1),\n",
" (31, u'class_text', u'Iris-versicolor', 4.3105167e-05, 2),\n",
" (31, u'class_text', u'Iris-virginica', 1.541236e-09, 3),\n",
" (49, u'class_text', u'Iris-setosa', 0.99999833, 1),\n",
" (49, u'class_text', u'Iris-versicolor', 1.6733742e-06, 2),\n",
" (49, u'class_text', u'Iris-virginica', 1.0720992e-11, 3),\n",
" (55, u'class_text', u'Iris-versicolor', 0.97456545, 1),\n",
" (55, u'class_text', u'Iris-virginica', 0.025385397, 2),\n",
" (55, u'class_text', u'Iris-setosa', 4.912654e-05, 3),\n",
" (64, u'class_text', u'Iris-versicolor', 0.8837083, 1),\n",
" (64, u'class_text', u'Iris-virginica', 0.11627731, 2),\n",
" (64, u'class_text', u'Iris-setosa', 1.4444132e-05, 3),\n",
" (70, u'class_text', u'Iris-versicolor', 0.9832433, 1),\n",
" (70, u'class_text', u'Iris-virginica', 0.016161945, 2),\n",
" (70, u'class_text', u'Iris-setosa', 0.0005947249, 3),\n",
" (76, u'class_text', u'Iris-versicolor', 0.9934144, 1),\n",
" (76, u'class_text', u'Iris-virginica', 0.006202936, 2),\n",
" (76, u'class_text', u'Iris-setosa', 0.00038262276, 3),\n",
" (82, u'class_text', u'Iris-versicolor', 0.9880006, 1),\n",
" (82, u'class_text', u'Iris-virginica', 0.01050145, 2),\n",
" (82, u'class_text', u'Iris-setosa', 0.0014980072, 3),\n",
" (84, u'class_text', u'Iris-virginica', 0.743757, 1),\n",
" (84, u'class_text', u'Iris-versicolor', 0.25624287, 2),\n",
" (84, u'class_text', u'Iris-setosa', 1.1804799e-07, 3),\n",
" (92, u'class_text', u'Iris-versicolor', 0.9489498, 1),\n",
" (92, u'class_text', u'Iris-virginica', 0.050999135, 2),\n",
" (92, u'class_text', u'Iris-setosa', 5.1051586e-05, 3),\n",
" (98, u'class_text', u'Iris-versicolor', 0.9882598, 1),\n",
" (98, u'class_text', u'Iris-virginica', 0.011410431, 2),\n",
" (98, u'class_text', u'Iris-setosa', 0.00032975432, 3),\n",
" (99, u'class_text', u'Iris-versicolor', 0.7122672, 1),\n",
" (99, u'class_text', u'Iris-setosa', 0.2864844, 2),\n",
" (99, u'class_text', u'Iris-virginica', 0.0012483773, 3),\n",
" (102, u'class_text', u'Iris-virginica', 0.8344315, 1),\n",
" (102, u'class_text', u'Iris-versicolor', 0.16556835, 2),\n",
" (102, u'class_text', u'Iris-setosa', 4.9313943e-08, 3),\n",
" (107, u'class_text', u'Iris-virginica', 0.7617606, 1),\n",
" (107, u'class_text', u'Iris-versicolor', 0.23823881, 2),\n",
" (107, u'class_text', u'Iris-setosa', 6.2156596e-07, 3),\n",
" (114, u'class_text', u'Iris-virginica', 0.85601324, 1),\n",
" (114, u'class_text', u'Iris-versicolor', 0.1439867, 2),\n",
" (114, u'class_text', u'Iris-setosa', 3.4068247e-08, 3),\n",
" (117, u'class_text', u'Iris-virginica', 0.76065344, 1),\n",
" (117, u'class_text', u'Iris-versicolor', 0.23934652, 2),\n",
" (117, u'class_text', u'Iris-setosa', 4.0775706e-08, 3),\n",
" (121, u'class_text', u'Iris-virginica', 0.65924823, 1),\n",
" (121, u'class_text', u'Iris-versicolor', 0.34075174, 2),\n",
" (121, u'class_text', u'Iris-setosa', 3.7877243e-08, 3),\n",
" (123, u'class_text', u'Iris-virginica', 0.968423, 1),\n",
" (123, u'class_text', u'Iris-versicolor', 0.031577036, 2),\n",
" (123, u'class_text', u'Iris-setosa', 1.5606285e-11, 3),\n",
" (125, u'class_text', u'Iris-virginica', 0.72842705, 1),\n",
" (125, u'class_text', u'Iris-versicolor', 0.2715729, 2),\n",
" (125, u'class_text', u'Iris-setosa', 3.7875385e-08, 3),\n",
" (127, u'class_text', u'Iris-versicolor', 0.8053533, 1),\n",
" (127, u'class_text', u'Iris-virginica', 0.19464317, 2),\n",
" (127, u'class_text', u'Iris-setosa', 3.5179064e-06, 3),\n",
" (145, u'class_text', u'Iris-virginica', 0.7297866, 1),\n",
" (145, u'class_text', u'Iris-versicolor', 0.2702134, 2),\n",
" (145, u'class_text', u'Iris-setosa', 2.8784607e-08, 3),\n",
" (147, u'class_text', u'Iris-virginica', 0.5341273, 1),\n",
" (147, u'class_text', u'Iris-versicolor', 0.4658725, 2),\n",
" (147, u'class_text', u'Iris-setosa', 2.3799986e-07, 3),\n",
" (148, u'class_text', u'Iris-versicolor', 0.6266347, 1),\n",
" (148, u'class_text', u'Iris-virginica', 0.3733647, 2),\n",
" (148, u'class_text', u'Iris-setosa', 5.7692125e-07, 3),\n",
" (149, u'class_text', u'Iris-virginica', 0.5517554, 1),\n",
" (149, u'class_text', u'Iris-versicolor', 0.4482443, 2),\n",
" (149, u'class_text', u'Iris-setosa', 3.1108453e-07, 3)]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict;\n",
"\n",
"SELECT madlib.madlib_keras_predict('iris_multi_model', -- model\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict', -- output table\n",
" 'prob', -- prediction type\n",
" FALSE, -- use gpus\n",
" 3 -- mst_key to use\n",
" );\n",
"\n",
"SELECT * FROM iris_predict ORDER BY id, rank;"
]
},
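{
"cell_type": "markdown",
"metadata": {},
"source": [
"With prediction type `'prob'`, the output table has one row per (id, class) pair, and `rank` 1 marks the most probable class. The snippet below is a minimal plain-Python sketch, not part of the MADlib API; the helper name `top_class` is hypothetical. It uses a few rows copied from the output above to recover the predicted class for each id and to check that each id's class probabilities sum to approximately 1:\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"# Sample rows copied from the 'prob' output above: (id, class_text, prob, rank)\n",
"rows = [\n",
"    (55, 'Iris-versicolor', 0.97456545, 1),\n",
"    (55, 'Iris-virginica', 0.025385397, 2),\n",
"    (55, 'Iris-setosa', 4.912654e-05, 3),\n",
"    (99, 'Iris-versicolor', 0.7122672, 1),\n",
"    (99, 'Iris-setosa', 0.2864844, 2),\n",
"    (99, 'Iris-virginica', 0.0012483773, 3),\n",
"]\n",
"\n",
"def top_class(rows):\n",
"    # Hypothetical helper: map each id to its rank-1 (most probable) class\n",
"    return {i: c for i, c, _, r in rows if r == 1}\n",
"\n",
"# Sum the class probabilities per id; they should total approximately 1.0\n",
"totals = defaultdict(float)\n",
"for i, _, p, _ in rows:\n",
"    totals[i] += p\n",
"\n",
"print(top_class(rows))  # {55: 'Iris-versicolor', 99: 'Iris-versicolor'}\n",
"print({i: round(t, 4) for i, t in totals.items()})  # {55: 1.0, 99: 1.0}\n",
"```\n",
"\n",
"For id 99 the rank-1 probability is only about 0.71, with Iris-setosa at rank 2, so looking beyond rank 1 shows where the model is less confident."
]
},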
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"warm_start\"></a>\n",
"# 3. Warm start\n",
"\n",
"Next, use the `warm_start` parameter to continue training from the weights learned in the run above. Note that we do not drop the model output table or the model summary table, since warm start initializes each model from the weights saved by the previous run:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit_multiple_model</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', -- source_table\n",
" 'iris_multi_model', -- model_output_table\n",
" 'mst_table', -- model_selection_table\n",
" 3, -- num_iterations\n",
" FALSE, -- use gpus\n",
" 'iris_test_packed', -- validation dataset\n",
" 1, -- metrics compute frequency\n",
" TRUE, -- warm start\n",
" 'Sophie L.', -- name\n",
" 'Simple MLP for iris dataset' -- description\n",
" );"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View summary:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>validation_table</th>\n",
" <th>model</th>\n",
" <th>model_info</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_selection_table</th>\n",
" <th>object_table</th>\n",
" <th>num_iterations</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>warm_start</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_text_class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_test_packed</td>\n",
" <td>iris_multi_model</td>\n",
" <td>iris_multi_model_info</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>mst_table</td>\n",
" <td>None</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>True</td>\n",
" <td>Sophie L.</td>\n",
" <td>Simple MLP for iris dataset</td>\n",
" <td>2021-03-06 00:55:34.010762</td>\n",
" <td>2021-03-06 00:56:20.576330</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[1]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[1, 2, 3]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_test_packed', u'iris_multi_model', u'iris_multi_model_info', [u'class_text'], [u'attributes'], u'model_arch_library', u'mst_table', None, 3, 1, True, u'Sophie L.', u'Simple MLP for iris dataset', datetime.datetime(2021, 3, 6, 0, 55, 34, 10762), datetime.datetime(2021, 3, 6, 0, 56, 20, 576330), u'1.18.0-dev', [1], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], [u'character varying'], 1.0, [1, 2, 3])]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the performance of each model, sorted by final validation accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[12.8246030807495, 28.3149819374084, 43.8511519432068]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.949999988079</td>\n",
" <td>0.125932246447</td>\n",
" <td>[0.983333349227905, 0.908333361148834, 0.949999988079071]</td>\n",
" <td>[0.0759517326951027, 0.280529856681824, 0.125932246446609]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.262804627419</td>\n",
" <td>[0.966666638851166, 0.933333337306976, 0.966666638851166]</td>\n",
" <td>[0.115140154957771, 0.282798647880554, 0.262804627418518]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[12.3267669677734, 27.5790538787842, 43.3719210624695]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.958333313465</td>\n",
" <td>0.646220803261</td>\n",
" <td>[0.916666686534882, 0.774999976158142, 0.958333313465118]</td>\n",
" <td>[0.760809063911438, 0.70676600933075, 0.646220803260803]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.676706075668</td>\n",
" <td>[0.899999976158142, 0.699999988079071, 0.966666638851166]</td>\n",
" <td>[0.789911270141602, 0.741125166416168, 0.676706075668335]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[13.8655989170074, 29.3921880722046, 45.186311006546]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.161019146442</td>\n",
" <td>[0.608333349227905, 0.975000023841858, 0.966666638851166]</td>\n",
" <td>[0.656926870346069, 0.154457986354828, 0.161019146442413]</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.184286847711</td>\n",
" <td>[0.666666686534882, 0.966666638851166, 0.966666638851166]</td>\n",
" <td>[0.60343611240387, 0.166501134634018, 0.184286847710609]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[12.5584180355072, 27.7957689762115, 43.5938129425049]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.925000011921</td>\n",
" <td>0.125614732504</td>\n",
" <td>[0.850000023841858, 0.908333361148834, 0.925000011920929]</td>\n",
" <td>[0.311796188354492, 0.228279903531075, 0.125614732503891]</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.205575048923</td>\n",
" <td>[0.699999988079071, 0.899999976158142, 0.933333337306976]</td>\n",
" <td>[0.434732705354691, 0.278642177581787, 0.205575048923492]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[14.3016650676727, 29.8289239406586, 45.6773319244385]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.916666686535</td>\n",
" <td>0.680241525173</td>\n",
" <td>[0.899999976158142, 0.899999976158142, 0.916666686534882]</td>\n",
" <td>[0.75947380065918, 0.717410624027252, 0.680241525173187]</td>\n",
" <td>0.933333337307</td>\n",
" <td>0.685820519924</td>\n",
" <td>[0.933333337306976, 0.933333337306976, 0.933333337306976]</td>\n",
" <td>[0.764581918716431, 0.718774557113647, 0.685820519924164]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[13.6457929611206, 29.1624140739441, 44.9534199237823]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.891666650772</td>\n",
" <td>0.590237081051</td>\n",
" <td>[0.824999988079071, 0.783333361148834, 0.891666650772095]</td>\n",
" <td>[0.666068911552429, 0.633061707019806, 0.590237081050873]</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.576045572758</td>\n",
" <td>[0.866666674613953, 0.866666674613953, 0.899999976158142]</td>\n",
" <td>[0.645683944225311, 0.608498632907867, 0.576045572757721]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[14.0837008953094, 29.6097829341888, 45.4142129421234]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.916666686535</td>\n",
" <td>0.174454689026</td>\n",
" <td>[0.949999988079071, 0.958333313465118, 0.916666686534882]</td>\n",
" <td>[0.166735425591469, 0.141851797699928, 0.174454689025879]</td>\n",
" <td>0.899999976158</td>\n",
" <td>0.219132959843</td>\n",
" <td>[0.966666638851166, 0.933333337306976, 0.899999976158142]</td>\n",
" <td>[0.186790466308594, 0.176578417420387, 0.219132959842682]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[13.1594960689545, 28.5860660076141, 44.1881170272827]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.866666674614</td>\n",
" <td>0.285291582346</td>\n",
" <td>[0.774999976158142, 0.949999988079071, 0.866666674613953]</td>\n",
" <td>[0.441815197467804, 0.140827313065529, 0.285291582345963]</td>\n",
" <td>0.866666674614</td>\n",
" <td>0.246576815844</td>\n",
" <td>[0.766666650772095, 0.966666638851166, 0.866666674613953]</td>\n",
" <td>[0.4128278195858, 0.146319955587387, 0.246576815843582]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[14.5546190738678, 30.0798380374908, 45.94082903862]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.850000023842</td>\n",
" <td>0.675731360912</td>\n",
" <td>[0.791666686534882, 0.841666638851166, 0.850000023841858]</td>\n",
" <td>[0.746130049228668, 0.706377267837524, 0.675731360912323]</td>\n",
" <td>0.866666674614</td>\n",
" <td>0.650432705879</td>\n",
" <td>[0.866666674613953, 0.866666674613953, 0.866666674613953]</td>\n",
" <td>[0.712817847728729, 0.677974581718445, 0.650432705879211]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[15.3575170040131, 30.5435180664062, 46.5635209083557]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.658333361149</td>\n",
" <td>0.45723798871</td>\n",
" <td>[0.658333361148834, 0.683333337306976, 0.658333361148834]</td>\n",
" <td>[0.457635939121246, 0.455960959196091, 0.457237988710403]</td>\n",
" <td>0.699999988079</td>\n",
" <td>0.48275628686</td>\n",
" <td>[0.699999988079071, 0.600000023841858, 0.699999988079071]</td>\n",
" <td>[0.48207613825798, 0.491984754800797, 0.482756286859512]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=4</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.75390625</td>\n",
" <td>[14.8466219902039, 30.2953569889069, 46.1656670570374]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.683333337307</td>\n",
" <td>0.456283688545</td>\n",
" <td>[0.925000011920929, 0.899999976158142, 0.683333337306976]</td>\n",
" <td>[0.224153310060501, 0.295417010784149, 0.456283688545227]</td>\n",
" <td>0.600000023842</td>\n",
" <td>0.494575560093</td>\n",
" <td>[0.966666638851166, 0.899999976158142, 0.600000023841858]</td>\n",
" <td>[0.227903217077255, 0.345975488424301, 0.494575560092926]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'</td>\n",
" <td>epochs=1,batch_size=8</td>\n",
" <td>madlib_keras</td>\n",
" <td>1.18359375</td>\n",
" <td>[13.4095330238342, 28.938658952713, 44.7153990268707]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.691666662693</td>\n",
" <td>0.528191685677</td>\n",
" <td>[0.708333313465118, 0.966666638851166, 0.691666662693024]</td>\n",
" <td>[0.395545929670334, 0.100506067276001, 0.528191685676575]</td>\n",
" <td>0.566666662693</td>\n",
" <td>0.720313131809</td>\n",
" <td>[0.633333325386047, 0.966666638851166, 0.566666662693024]</td>\n",
" <td>[0.508394777774811, 0.130626574158669, 0.720313131809235]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(9, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [12.8246030807495, 28.3149819374084, 43.8511519432068], [u'accuracy'], u'categorical_crossentropy', 0.949999988079071, 0.125932246446609, [0.983333349227905, 0.908333361148834, 0.949999988079071], [0.0759517326951027, 0.280529856681824, 0.125932246446609], 0.966666638851166, 0.262804627418518, [0.966666638851166, 0.933333337306976, 0.966666638851166], [0.115140154957771, 0.282798647880554, 0.262804627418518]),\n",
" (7, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [12.3267669677734, 27.5790538787842, 43.3719210624695], [u'accuracy'], u'categorical_crossentropy', 0.958333313465118, 0.646220803260803, [0.916666686534882, 0.774999976158142, 0.958333313465118], [0.760809063911438, 0.70676600933075, 0.646220803260803], 0.966666638851166, 0.676706075668335, [0.899999976158142, 0.699999988079071, 0.966666638851166], [0.789911270141602, 0.741125166416168, 0.676706075668335]),\n",
" (6, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [13.8655989170074, 29.3921880722046, 45.186311006546], [u'accuracy'], u'categorical_crossentropy', 0.966666638851166, 0.161019146442413, [0.608333349227905, 0.975000023841858, 0.966666638851166], [0.656926870346069, 0.154457986354828, 0.161019146442413], 0.966666638851166, 0.184286847710609, [0.666666686534882, 0.966666638851166, 0.966666638851166], [0.60343611240387, 0.166501134634018, 0.184286847710609]),\n",
" (3, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [12.5584180355072, 27.7957689762115, 43.5938129425049], [u'accuracy'], u'categorical_crossentropy', 0.925000011920929, 0.125614732503891, [0.850000023841858, 0.908333361148834, 0.925000011920929], [0.311796188354492, 0.228279903531075, 0.125614732503891], 0.933333337306976, 0.205575048923492, [0.699999988079071, 0.899999976158142, 0.933333337306976], [0.434732705354691, 0.278642177581787, 0.205575048923492]),\n",
" (1, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [14.3016650676727, 29.8289239406586, 45.6773319244385], [u'accuracy'], u'categorical_crossentropy', 0.916666686534882, 0.680241525173187, [0.899999976158142, 0.899999976158142, 0.916666686534882], [0.75947380065918, 0.717410624027252, 0.680241525173187], 0.933333337306976, 0.685820519924164, [0.933333337306976, 0.933333337306976, 0.933333337306976], [0.764581918716431, 0.718774557113647, 0.685820519924164]),\n",
" (2, 1, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [13.6457929611206, 29.1624140739441, 44.9534199237823], [u'accuracy'], u'categorical_crossentropy', 0.891666650772095, 0.590237081050873, [0.824999988079071, 0.783333361148834, 0.891666650772095], [0.666068911552429, 0.633061707019806, 0.590237081050873], 0.899999976158142, 0.576045572757721, [0.866666674613953, 0.866666674613953, 0.899999976158142], [0.645683944225311, 0.608498632907867, 0.576045572757721]),\n",
" (4, 1, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 0.75390625, [14.0837008953094, 29.6097829341888, 45.4142129421234], [u'accuracy'], u'categorical_crossentropy', 0.916666686534882, 0.174454689025879, [0.949999988079071, 0.958333313465118, 0.916666686534882], [0.166735425591469, 0.141851797699928, 0.174454689025879], 0.899999976158142, 0.219132959842682, [0.966666638851166, 0.933333337306976, 0.899999976158142], [0.186790466308594, 0.176578417420387, 0.219132959842682]),\n",
" (10, 2, u\"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [13.1594960689545, 28.5860660076141, 44.1881170272827], [u'accuracy'], u'categorical_crossentropy', 0.866666674613953, 0.285291582345963, [0.774999976158142, 0.949999988079071, 0.866666674613953], [0.441815197467804, 0.140827313065529, 0.285291582345963], 0.866666674613953, 0.246576815843582, [0.766666650772095, 0.966666638851166, 0.866666674613953], [0.4128278195858, 0.146319955587387, 0.246576815843582]),\n",
" (8, 2, u\"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [14.5546190738678, 30.0798380374908, 45.94082903862], [u'accuracy'], u'categorical_crossentropy', 0.850000023841858, 0.675731360912323, [0.791666686534882, 0.841666638851166, 0.850000023841858], [0.746130049228668, 0.706377267837524, 0.675731360912323], 0.866666674613953, 0.650432705879211, [0.866666674613953, 0.866666674613953, 0.866666674613953], [0.712817847728729, 0.677974581718445, 0.650432705879211]),\n",
" (11, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 1.18359375, [15.3575170040131, 30.5435180664062, 46.5635209083557], [u'accuracy'], u'categorical_crossentropy', 0.658333361148834, 0.457237988710403, [0.658333361148834, 0.683333337306976, 0.658333361148834], [0.457635939121246, 0.455960959196091, 0.457237988710403], 0.699999988079071, 0.482756286859512, [0.699999988079071, 0.600000023841858, 0.699999988079071], [0.48207613825798, 0.491984754800797, 0.482756286859512]),\n",
" (5, 1, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=4', u'madlib_keras', 0.75390625, [14.8466219902039, 30.2953569889069, 46.1656670570374], [u'accuracy'], u'categorical_crossentropy', 0.683333337306976, 0.456283688545227, [0.925000011920929, 0.899999976158142, 0.683333337306976], [0.224153310060501, 0.295417010784149, 0.456283688545227], 0.600000023841858, 0.494575560092926, [0.966666638851166, 0.899999976158142, 0.600000023841858], [0.227903217077255, 0.345975488424301, 0.494575560092926]),\n",
" (12, 2, u\"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'\", u'epochs=1,batch_size=8', u'madlib_keras', 1.18359375, [13.4095330238342, 28.938658952713, 44.7153990268707], [u'accuracy'], u'categorical_crossentropy', 0.691666662693024, 0.528191685676575, [0.708333313465118, 0.966666638851166, 0.691666662693024], [0.395545929670334, 0.100506067276001, 0.528191685676575], 0.566666662693024, 0.720313131809235, [0.633333325386047, 0.966666638851166, 0.566666662693024], [0.508394777774811, 0.130626574158669, 0.720313131809235])]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_multi_model_info ORDER BY validation_metrics_final DESC;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot the validation metric and loss curves, by iteration, for the models with the lowest final validation loss:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7 rows affected.\n",
"1 rows affected.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"\n",
"# Fetch the 7 models with the lowest final validation loss\n",
"df_results = %sql SELECT * FROM iris_multi_model_info ORDER BY validation_loss_final ASC LIMIT 7;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"# The summary table stores the iterations at which metrics were computed\n",
"df_summary = %sql SELECT * FROM iris_multi_model_summary;\n",
"df_summary = df_summary.DataFrame()\n",
"\n",
"# Set up one panel for the validation metric and one for the validation loss\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Validation metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Validation loss curve')\n",
"\n",
"iters = df_summary['metrics_iters'][0]\n",
"\n",
"# Plot one curve per model, labeled by its model selection key\n",
"for mst_key in df_results['mst_key']:\n",
"    df_output_info = %sql SELECT validation_metrics,validation_loss FROM iris_multi_model_info WHERE mst_key = $mst_key\n",
"    df_output_info = df_output_info.DataFrame()\n",
"    validation_metrics = df_output_info['validation_metrics'][0]\n",
"    validation_loss = df_output_info['validation_loss'][0]\n",
"\n",
"    ax_metric.plot(iters, validation_metrics, label=mst_key, marker='o')\n",
"    ax_loss.plot(iters, validation_loss, label=mst_key, marker='o')\n",
"\n",
"# Lay out the panels and add the legend only after all curves are drawn\n",
"fig.tight_layout()\n",
"plt.legend();\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 1
}