{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multilayer Perceptron Using Keras and MADlib\n",
"\n",
"End-to-end (E2E) classification example in which MADlib calls a Keras multilayer perceptron (MLP).\n",
"\n",
"Deep learning works best on very large datasets, but those are not convenient for a quick introduction to the syntax. So in this workbook we use the well-known iris data set from https://archive.ics.uci.edu/ml/datasets/iris to help get you started. It is similar to the example in the MADlib user documentation at http://madlib.apache.org/docs/latest/index.html\n",
"\n",
"For more realistic examples with images, please refer to the deep learning notebooks at\n",
"https://github.com/apache/madlib-site/tree/asf-site/community-artifacts\n",
"\n",
"## Table of contents\n",
"\n",
"<a href=\"#class\">Classification</a>\n",
"\n",
"* <a href=\"#create_input_data\">1. Create input data</a>\n",
"\n",
"* <a href=\"#pp\">2. Call preprocessor for deep learning</a>\n",
"\n",
"* <a href=\"#load\">3. Define and load model architecture</a>\n",
"\n",
"* <a href=\"#train\">4. Train</a>\n",
"\n",
"* <a href=\"#eval\">5. Evaluate</a>\n",
"\n",
"* <a href=\"#pred\">6. Predict</a>\n",
"\n",
"* <a href=\"#pred_byom\">7. Predict BYOM</a>\n",
"\n",
"<a href=\"#class2\">Classification with Other Parameters</a>\n",
"\n",
"* <a href=\"#val_dataset\">1. Validation dataset</a>\n",
"\n",
"* <a href=\"#pred_prob\">2. Predict probabilities</a>\n",
"\n",
"* <a href=\"#warm_start\">3. Warm start</a>\n",
"\n",
"<a href=\"#transfer_learn\">Transfer learning</a>\n",
"\n",
"* <a href=\"#load2\">1. Define and load model architecture with some layers frozen</a>\n",
"\n",
"* <a href=\"#train2\">2. Train transfer model</a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Greenplum Database 5.x on GCP - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.18.0-dev, git revision: rel/v1.17.0-89-g14a91ce, cmake configuration time: Fri Mar 5 23:08:38 UTC 2021, build type: release, build system: Linux-3.10.0-1160.11.1.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.18.0-dev, git revision: rel/v1.17.0-89-g14a91ce, cmake configuration time: Fri Mar 5 23:08:38 UTC 2021, build type: release, build system: Linux-3.10.0-1160.11.1.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"class\"></a>\n",
"# Classification\n",
"\n",
"<a id=\"create_input_data\"></a>\n",
"# 1. Create input data\n",
"\n",
"Load the iris data set."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"150 rows affected.\n",
"150 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>attributes</th>\n",
" <th>class_text</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>[Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>[Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>[Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>[Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>[Decimal('5.7'), Decimal('3.8'), Decimal('1.7'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.5'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>[Decimal('5.4'), Decimal('3.4'), Decimal('1.7'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>[Decimal('5.1'), Decimal('3.7'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>[Decimal('4.6'), Decimal('3.6'), Decimal('1.0'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>[Decimal('5.1'), Decimal('3.3'), Decimal('1.7'), Decimal('0.5')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.9'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>[Decimal('5.0'), Decimal('3.0'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.6'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>[Decimal('5.2'), Decimal('3.5'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>[Decimal('5.2'), Decimal('3.4'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>[Decimal('4.8'), Decimal('3.1'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32</td>\n",
" <td>[Decimal('5.4'), Decimal('3.4'), Decimal('1.5'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33</td>\n",
" <td>[Decimal('5.2'), Decimal('4.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>[Decimal('5.5'), Decimal('4.2'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36</td>\n",
" <td>[Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37</td>\n",
" <td>[Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>39</td>\n",
" <td>[Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>[Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>41</td>\n",
" <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>42</td>\n",
" <td>[Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>[Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>44</td>\n",
" <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>46</td>\n",
" <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>47</td>\n",
" <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>48</td>\n",
" <td>[Decimal('4.6'), Decimal('3.2'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>[Decimal('5.3'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>[Decimal('5.0'), Decimal('3.3'), Decimal('1.4'), Decimal('0.2')]</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>[Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>52</td>\n",
" <td>[Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>53</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>54</td>\n",
" <td>[Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>[Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>57</td>\n",
" <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>58</td>\n",
" <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>[Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>62</td>\n",
" <td>[Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>63</td>\n",
" <td>[Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>[Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>65</td>\n",
" <td>[Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>66</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>68</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>69</td>\n",
" <td>[Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>[Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>71</td>\n",
" <td>[Decimal('5.9'), Decimal('3.2'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>[Decimal('6.1'), Decimal('2.8'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>73</td>\n",
" <td>[Decimal('6.3'), Decimal('2.5'), Decimal('4.9'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>74</td>\n",
" <td>[Decimal('6.1'), Decimal('2.8'), Decimal('4.7'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>[Decimal('6.4'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>[Decimal('6.6'), Decimal('3.0'), Decimal('4.4'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77</td>\n",
" <td>[Decimal('6.8'), Decimal('2.8'), Decimal('4.8'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>78</td>\n",
" <td>[Decimal('6.7'), Decimal('3.0'), Decimal('5.0'), Decimal('1.7')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>79</td>\n",
" <td>[Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>[Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>81</td>\n",
" <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>83</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>[Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>85</td>\n",
" <td>[Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>86</td>\n",
" <td>[Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>87</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>88</td>\n",
" <td>[Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>90</td>\n",
" <td>[Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>91</td>\n",
" <td>[Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>93</td>\n",
" <td>[Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>[Decimal('5.0'), Decimal('2.3'), Decimal('3.3'), Decimal('1.0')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>95</td>\n",
" <td>[Decimal('5.6'), Decimal('2.7'), Decimal('4.2'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>96</td>\n",
" <td>[Decimal('5.7'), Decimal('3.0'), Decimal('4.2'), Decimal('1.2')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>97</td>\n",
" <td>[Decimal('5.7'), Decimal('2.9'), Decimal('4.2'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>[Decimal('6.2'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>[Decimal('5.1'), Decimal('2.5'), Decimal('3.0'), Decimal('1.1')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.1'), Decimal('1.3')]</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>[Decimal('6.3'), Decimal('3.3'), Decimal('6.0'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>[Decimal('7.1'), Decimal('3.0'), Decimal('5.9'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>104</td>\n",
" <td>[Decimal('6.3'), Decimal('2.9'), Decimal('5.6'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>105</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.8'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>106</td>\n",
" <td>[Decimal('7.6'), Decimal('3.0'), Decimal('6.6'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>107</td>\n",
" <td>[Decimal('4.9'), Decimal('2.5'), Decimal('4.5'), Decimal('1.7')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>108</td>\n",
" <td>[Decimal('7.3'), Decimal('2.9'), Decimal('6.3'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>109</td>\n",
" <td>[Decimal('6.7'), Decimal('2.5'), Decimal('5.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>110</td>\n",
" <td>[Decimal('7.2'), Decimal('3.6'), Decimal('6.1'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>111</td>\n",
" <td>[Decimal('6.5'), Decimal('3.2'), Decimal('5.1'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>[Decimal('6.4'), Decimal('2.7'), Decimal('5.3'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>[Decimal('6.8'), Decimal('3.0'), Decimal('5.5'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>114</td>\n",
" <td>[Decimal('5.7'), Decimal('2.5'), Decimal('5.0'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>[Decimal('5.8'), Decimal('2.8'), Decimal('5.1'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>[Decimal('6.4'), Decimal('3.2'), Decimal('5.3'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.5'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>118</td>\n",
" <td>[Decimal('7.7'), Decimal('3.8'), Decimal('6.7'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>[Decimal('7.7'), Decimal('2.6'), Decimal('6.9'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>[Decimal('6.0'), Decimal('2.2'), Decimal('5.0'), Decimal('1.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>121</td>\n",
" <td>[Decimal('6.9'), Decimal('3.2'), Decimal('5.7'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>122</td>\n",
" <td>[Decimal('5.6'), Decimal('2.8'), Decimal('4.9'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>123</td>\n",
" <td>[Decimal('7.7'), Decimal('2.8'), Decimal('6.7'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>124</td>\n",
" <td>[Decimal('6.3'), Decimal('2.7'), Decimal('4.9'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>[Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>126</td>\n",
" <td>[Decimal('7.2'), Decimal('3.2'), Decimal('6.0'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>[Decimal('6.2'), Decimal('2.8'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>128</td>\n",
" <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.9'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>129</td>\n",
" <td>[Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>130</td>\n",
" <td>[Decimal('7.2'), Decimal('3.0'), Decimal('5.8'), Decimal('1.6')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>131</td>\n",
" <td>[Decimal('7.4'), Decimal('2.8'), Decimal('6.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>132</td>\n",
" <td>[Decimal('7.9'), Decimal('3.8'), Decimal('6.4'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>133</td>\n",
" <td>[Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.2')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>134</td>\n",
" <td>[Decimal('6.3'), Decimal('2.8'), Decimal('5.1'), Decimal('1.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>135</td>\n",
" <td>[Decimal('6.1'), Decimal('2.6'), Decimal('5.6'), Decimal('1.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>[Decimal('7.7'), Decimal('3.0'), Decimal('6.1'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>137</td>\n",
" <td>[Decimal('6.3'), Decimal('3.4'), Decimal('5.6'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>138</td>\n",
" <td>[Decimal('6.4'), Decimal('3.1'), Decimal('5.5'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>139</td>\n",
" <td>[Decimal('6.0'), Decimal('3.0'), Decimal('4.8'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>140</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('5.4'), Decimal('2.1')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>141</td>\n",
" <td>[Decimal('6.7'), Decimal('3.1'), Decimal('5.6'), Decimal('2.4')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>142</td>\n",
" <td>[Decimal('6.9'), Decimal('3.1'), Decimal('5.1'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>143</td>\n",
" <td>[Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>[Decimal('6.8'), Decimal('3.2'), Decimal('5.9'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>[Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.5')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>[Decimal('6.7'), Decimal('3.0'), Decimal('5.2'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>[Decimal('6.3'), Decimal('2.5'), Decimal('5.0'), Decimal('1.9')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>148</td>\n",
" <td>[Decimal('6.5'), Decimal('3.0'), Decimal('5.2'), Decimal('2.0')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>149</td>\n",
" <td>[Decimal('6.2'), Decimal('3.4'), Decimal('5.4'), Decimal('2.3')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <td>150</td>\n",
" <td>[Decimal('5.9'), Decimal('3.0'), Decimal('5.1'), Decimal('1.8')]</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (2, [Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (3, [Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (4, [Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (5, [Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (6, [Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')], u'Iris-setosa'),\n",
" (7, [Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (8, [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (9, [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (10, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (11, [Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (12, [Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (13, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')], u'Iris-setosa'),\n",
" (14, [Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')], u'Iris-setosa'),\n",
" (15, [Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')], u'Iris-setosa'),\n",
" (16, [Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (17, [Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')], u'Iris-setosa'),\n",
" (18, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (19, [Decimal('5.7'), Decimal('3.8'), Decimal('1.7'), Decimal('0.3')], u'Iris-setosa'),\n",
" (20, [Decimal('5.1'), Decimal('3.8'), Decimal('1.5'), Decimal('0.3')], u'Iris-setosa'),\n",
" (21, [Decimal('5.4'), Decimal('3.4'), Decimal('1.7'), Decimal('0.2')], u'Iris-setosa'),\n",
" (22, [Decimal('5.1'), Decimal('3.7'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (23, [Decimal('4.6'), Decimal('3.6'), Decimal('1.0'), Decimal('0.2')], u'Iris-setosa'),\n",
" (24, [Decimal('5.1'), Decimal('3.3'), Decimal('1.7'), Decimal('0.5')], u'Iris-setosa'),\n",
" (25, [Decimal('4.8'), Decimal('3.4'), Decimal('1.9'), Decimal('0.2')], u'Iris-setosa'),\n",
" (26, [Decimal('5.0'), Decimal('3.0'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (27, [Decimal('5.0'), Decimal('3.4'), Decimal('1.6'), Decimal('0.4')], u'Iris-setosa'),\n",
" (28, [Decimal('5.2'), Decimal('3.5'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (29, [Decimal('5.2'), Decimal('3.4'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (30, [Decimal('4.7'), Decimal('3.2'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (31, [Decimal('4.8'), Decimal('3.1'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (32, [Decimal('5.4'), Decimal('3.4'), Decimal('1.5'), Decimal('0.4')], u'Iris-setosa'),\n",
" (33, [Decimal('5.2'), Decimal('4.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (34, [Decimal('5.5'), Decimal('4.2'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (35, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (36, [Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')], u'Iris-setosa'),\n",
" (37, [Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (38, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa'),\n",
" (39, [Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (40, [Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (41, [Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')], u'Iris-setosa'),\n",
" (42, [Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')], u'Iris-setosa'),\n",
" (43, [Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa'),\n",
" (44, [Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')], u'Iris-setosa'),\n",
" (45, [Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')], u'Iris-setosa'),\n",
" (46, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa'),\n",
" (47, [Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')], u'Iris-setosa'),\n",
" (48, [Decimal('4.6'), Decimal('3.2'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (49, [Decimal('5.3'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa'),\n",
" (50, [Decimal('5.0'), Decimal('3.3'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa'),\n",
" (51, [Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (52, [Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (53, [Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (54, [Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (55, [Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (56, [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (57, [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (58, [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (59, [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (60, [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (61, [Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (62, [Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (63, [Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (64, [Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (65, [Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (66, [Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (67, [Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (68, [Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (69, [Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (70, [Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (71, [Decimal('5.9'), Decimal('3.2'), Decimal('4.8'), Decimal('1.8')], u'Iris-versicolor'),\n",
" (72, [Decimal('6.1'), Decimal('2.8'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (73, [Decimal('6.3'), Decimal('2.5'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (74, [Decimal('6.1'), Decimal('2.8'), Decimal('4.7'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (75, [Decimal('6.4'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (76, [Decimal('6.6'), Decimal('3.0'), Decimal('4.4'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (77, [Decimal('6.8'), Decimal('2.8'), Decimal('4.8'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (78, [Decimal('6.7'), Decimal('3.0'), Decimal('5.0'), Decimal('1.7')], u'Iris-versicolor'),\n",
" (79, [Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (80, [Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (81, [Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (82, [Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (83, [Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (84, [Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (85, [Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (86, [Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')], u'Iris-versicolor'),\n",
" (87, [Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')], u'Iris-versicolor'),\n",
" (88, [Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (89, [Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (90, [Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (91, [Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (92, [Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')], u'Iris-versicolor'),\n",
" (93, [Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (94, [Decimal('5.0'), Decimal('2.3'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor'),\n",
" (95, [Decimal('5.6'), Decimal('2.7'), Decimal('4.2'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (96, [Decimal('5.7'), Decimal('3.0'), Decimal('4.2'), Decimal('1.2')], u'Iris-versicolor'),\n",
" (97, [Decimal('5.7'), Decimal('2.9'), Decimal('4.2'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (98, [Decimal('6.2'), Decimal('2.9'), Decimal('4.3'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (99, [Decimal('5.1'), Decimal('2.5'), Decimal('3.0'), Decimal('1.1')], u'Iris-versicolor'),\n",
" (100, [Decimal('5.7'), Decimal('2.8'), Decimal('4.1'), Decimal('1.3')], u'Iris-versicolor'),\n",
" (101, [Decimal('6.3'), Decimal('3.3'), Decimal('6.0'), Decimal('2.5')], u'Iris-virginica'),\n",
" (102, [Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (103, [Decimal('7.1'), Decimal('3.0'), Decimal('5.9'), Decimal('2.1')], u'Iris-virginica'),\n",
" (104, [Decimal('6.3'), Decimal('2.9'), Decimal('5.6'), Decimal('1.8')], u'Iris-virginica'),\n",
" (105, [Decimal('6.5'), Decimal('3.0'), Decimal('5.8'), Decimal('2.2')], u'Iris-virginica'),\n",
" (106, [Decimal('7.6'), Decimal('3.0'), Decimal('6.6'), Decimal('2.1')], u'Iris-virginica'),\n",
" (107, [Decimal('4.9'), Decimal('2.5'), Decimal('4.5'), Decimal('1.7')], u'Iris-virginica'),\n",
" (108, [Decimal('7.3'), Decimal('2.9'), Decimal('6.3'), Decimal('1.8')], u'Iris-virginica'),\n",
" (109, [Decimal('6.7'), Decimal('2.5'), Decimal('5.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (110, [Decimal('7.2'), Decimal('3.6'), Decimal('6.1'), Decimal('2.5')], u'Iris-virginica'),\n",
" (111, [Decimal('6.5'), Decimal('3.2'), Decimal('5.1'), Decimal('2.0')], u'Iris-virginica'),\n",
" (112, [Decimal('6.4'), Decimal('2.7'), Decimal('5.3'), Decimal('1.9')], u'Iris-virginica'),\n",
" (113, [Decimal('6.8'), Decimal('3.0'), Decimal('5.5'), Decimal('2.1')], u'Iris-virginica'),\n",
" (114, [Decimal('5.7'), Decimal('2.5'), Decimal('5.0'), Decimal('2.0')], u'Iris-virginica'),\n",
" (115, [Decimal('5.8'), Decimal('2.8'), Decimal('5.1'), Decimal('2.4')], u'Iris-virginica'),\n",
" (116, [Decimal('6.4'), Decimal('3.2'), Decimal('5.3'), Decimal('2.3')], u'Iris-virginica'),\n",
" (117, [Decimal('6.5'), Decimal('3.0'), Decimal('5.5'), Decimal('1.8')], u'Iris-virginica'),\n",
" (118, [Decimal('7.7'), Decimal('3.8'), Decimal('6.7'), Decimal('2.2')], u'Iris-virginica'),\n",
" (119, [Decimal('7.7'), Decimal('2.6'), Decimal('6.9'), Decimal('2.3')], u'Iris-virginica'),\n",
" (120, [Decimal('6.0'), Decimal('2.2'), Decimal('5.0'), Decimal('1.5')], u'Iris-virginica'),\n",
" (121, [Decimal('6.9'), Decimal('3.2'), Decimal('5.7'), Decimal('2.3')], u'Iris-virginica'),\n",
" (122, [Decimal('5.6'), Decimal('2.8'), Decimal('4.9'), Decimal('2.0')], u'Iris-virginica'),\n",
" (123, [Decimal('7.7'), Decimal('2.8'), Decimal('6.7'), Decimal('2.0')], u'Iris-virginica'),\n",
" (124, [Decimal('6.3'), Decimal('2.7'), Decimal('4.9'), Decimal('1.8')], u'Iris-virginica'),\n",
" (125, [Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.1')], u'Iris-virginica'),\n",
" (126, [Decimal('7.2'), Decimal('3.2'), Decimal('6.0'), Decimal('1.8')], u'Iris-virginica'),\n",
" (127, [Decimal('6.2'), Decimal('2.8'), Decimal('4.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (128, [Decimal('6.1'), Decimal('3.0'), Decimal('4.9'), Decimal('1.8')], u'Iris-virginica'),\n",
" (129, [Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.1')], u'Iris-virginica'),\n",
" (130, [Decimal('7.2'), Decimal('3.0'), Decimal('5.8'), Decimal('1.6')], u'Iris-virginica'),\n",
" (131, [Decimal('7.4'), Decimal('2.8'), Decimal('6.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (132, [Decimal('7.9'), Decimal('3.8'), Decimal('6.4'), Decimal('2.0')], u'Iris-virginica'),\n",
" (133, [Decimal('6.4'), Decimal('2.8'), Decimal('5.6'), Decimal('2.2')], u'Iris-virginica'),\n",
" (134, [Decimal('6.3'), Decimal('2.8'), Decimal('5.1'), Decimal('1.5')], u'Iris-virginica'),\n",
" (135, [Decimal('6.1'), Decimal('2.6'), Decimal('5.6'), Decimal('1.4')], u'Iris-virginica'),\n",
" (136, [Decimal('7.7'), Decimal('3.0'), Decimal('6.1'), Decimal('2.3')], u'Iris-virginica'),\n",
" (137, [Decimal('6.3'), Decimal('3.4'), Decimal('5.6'), Decimal('2.4')], u'Iris-virginica'),\n",
" (138, [Decimal('6.4'), Decimal('3.1'), Decimal('5.5'), Decimal('1.8')], u'Iris-virginica'),\n",
" (139, [Decimal('6.0'), Decimal('3.0'), Decimal('4.8'), Decimal('1.8')], u'Iris-virginica'),\n",
" (140, [Decimal('6.9'), Decimal('3.1'), Decimal('5.4'), Decimal('2.1')], u'Iris-virginica'),\n",
" (141, [Decimal('6.7'), Decimal('3.1'), Decimal('5.6'), Decimal('2.4')], u'Iris-virginica'),\n",
" (142, [Decimal('6.9'), Decimal('3.1'), Decimal('5.1'), Decimal('2.3')], u'Iris-virginica'),\n",
" (143, [Decimal('5.8'), Decimal('2.7'), Decimal('5.1'), Decimal('1.9')], u'Iris-virginica'),\n",
" (144, [Decimal('6.8'), Decimal('3.2'), Decimal('5.9'), Decimal('2.3')], u'Iris-virginica'),\n",
" (145, [Decimal('6.7'), Decimal('3.3'), Decimal('5.7'), Decimal('2.5')], u'Iris-virginica'),\n",
" (146, [Decimal('6.7'), Decimal('3.0'), Decimal('5.2'), Decimal('2.3')], u'Iris-virginica'),\n",
" (147, [Decimal('6.3'), Decimal('2.5'), Decimal('5.0'), Decimal('1.9')], u'Iris-virginica'),\n",
" (148, [Decimal('6.5'), Decimal('3.0'), Decimal('5.2'), Decimal('2.0')], u'Iris-virginica'),\n",
" (149, [Decimal('6.2'), Decimal('3.4'), Decimal('5.4'), Decimal('2.3')], u'Iris-virginica'),\n",
" (150, [Decimal('5.9'), Decimal('3.0'), Decimal('5.1'), Decimal('1.8')], u'Iris-virginica')]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS iris_data;\n",
"\n",
"CREATE TABLE iris_data(\n",
" id serial,\n",
" attributes numeric[],\n",
" class_text varchar\n",
");\n",
"\n",
"INSERT INTO iris_data(id, attributes, class_text) VALUES\n",
"(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),\n",
"(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),\n",
"(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),\n",
"(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),\n",
"(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),\n",
"(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),\n",
"(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),\n",
"(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),\n",
"(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),\n",
"(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),\n",
"(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),\n",
"(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),\n",
"(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),\n",
"(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),\n",
"(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),\n",
"(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),\n",
"(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),\n",
"(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),\n",
"(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),\n",
"(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),\n",
"(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),\n",
"(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),\n",
"(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),\n",
"(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),\n",
"(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),\n",
"(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),\n",
"(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),\n",
"(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),\n",
"(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),\n",
"(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),\n",
"(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),\n",
"(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),\n",
"(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),\n",
"(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),\n",
"(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),\n",
"(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),\n",
"(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),\n",
"(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),\n",
"(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),\n",
"(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),\n",
"(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),\n",
"(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),\n",
"(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),\n",
"(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),\n",
"(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),\n",
"(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),\n",
"(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),\n",
"(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),\n",
"(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),\n",
"(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),\n",
"(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),\n",
"(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),\n",
"(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),\n",
"(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),\n",
"(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),\n",
"(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),\n",
"(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),\n",
"(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),\n",
"(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),\n",
"(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),\n",
"(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),\n",
"(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),\n",
"(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),\n",
"(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),\n",
"(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),\n",
"(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),\n",
"(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),\n",
"(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),\n",
"(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),\n",
"(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),\n",
"(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),\n",
"(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),\n",
"(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),\n",
"(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),\n",
"(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),\n",
"(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),\n",
"(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),\n",
"(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),\n",
"(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),\n",
"(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),\n",
"(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),\n",
"(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),\n",
"(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),\n",
"(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),\n",
"(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),\n",
"(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),\n",
"(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),\n",
"(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),\n",
"(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),\n",
"(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),\n",
"(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),\n",
"(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),\n",
"(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),\n",
"(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),\n",
"(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),\n",
"(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),\n",
"(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),\n",
"(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),\n",
"(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),\n",
"(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),\n",
"(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),\n",
"(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),\n",
"(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),\n",
"(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),\n",
"(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),\n",
"(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),\n",
"(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),\n",
"(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),\n",
"(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),\n",
"(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),\n",
"(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),\n",
"(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),\n",
"(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),\n",
"(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),\n",
"(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),\n",
"(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),\n",
"(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),\n",
"(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),\n",
"(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),\n",
"(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),\n",
"(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),\n",
"(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),\n",
"(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),\n",
"(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),\n",
"(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),\n",
"(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),\n",
"(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),\n",
"(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),\n",
"(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),\n",
"(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),\n",
"(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),\n",
"(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),\n",
"(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),\n",
"(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),\n",
"(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),\n",
"(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),\n",
"(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),\n",
"(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),\n",
"(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),\n",
"(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),\n",
"(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),\n",
"(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),\n",
"(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),\n",
"(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),\n",
"(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),\n",
"(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),\n",
"(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),\n",
"(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');\n",
"\n",
"SELECT * FROM iris_data ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split `iris_data` into a training table (80% of rows) and a test/validation table (20% of rows):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(120L,)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_train, iris_test;\n",
"\n",
"-- Set seed so results are reproducible\n",
"SELECT setseed(0);\n",
"\n",
"SELECT madlib.train_test_split('iris_data', -- Source table\n",
" 'iris', -- Output table root name\n",
" 0.8, -- Train proportion\n",
" NULL, -- Test proportion (0.2)\n",
" NULL, -- Strata definition\n",
" NULL, -- Output all columns\n",
" NULL, -- Sample without replacement\n",
" TRUE -- Separate output tables\n",
" );\n",
"\n",
"SELECT COUNT(*) FROM iris_train;"
]
},
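{
"cell_type": "markdown",
"metadata": {},
"source": [
"The count above follows from applying the 0.8 train proportion to the 150 rows of `iris_data`. Sampling is without replacement and the test proportion defaults to the remainder (0.2), so each row should land in exactly one of the two tables:\n",
"\n",
"$$n_{\\text{train}} = 0.8 \\times 150 = 120, \\qquad n_{\\text{test}} = 150 - 120 = 30$$"
]
},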
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pp\"></a>\n",
"# 2. Call preprocessor for deep learning\n",
"Training dataset (uses training preprocessor):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_text_class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train</td>\n",
" <td>iris_train_packed</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>60</td>\n",
" <td>1.0</td>\n",
" <td>[3]</td>\n",
" <td>all_segments</td>\n",
" <td>all_segments</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train', u'iris_train_packed', [u'class_text'], [u'attributes'], [u'character varying'], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], 60, 1.0, [3], 'all_segments', 'all_segments')]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;\n",
"\n",
"SELECT madlib.training_preprocessor_dl('iris_train', -- Source table\n",
" 'iris_train_packed', -- Output table\n",
" 'class_text', -- Dependent variable\n",
" 'attributes' -- Independent variable\n",
" );\n",
"\n",
"SELECT * FROM iris_train_packed_summary;"
]
},
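{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the summary above, `normalizing_const` is 1.0, so each attribute value is divided by 1.0 and left unchanged. `num_classes` is 3. The preprocessor one-hot encodes the text labels. Assuming the encoding follows the order listed in `class_text_class_values`, each label becomes a 3-element vector:\n",
"\n",
"$$\\text{Iris-setosa} \\mapsto [1, 0, 0], \\quad \\text{Iris-versicolor} \\mapsto [0, 1, 0], \\quad \\text{Iris-virginica} \\mapsto [0, 0, 1]$$\n",
"\n",
"This is the target shape that the `categorical_crossentropy` loss and the 3-unit softmax output layer expect."
]
},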
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Validation dataset (uses validation preprocessor):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_text_class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_test</td>\n",
" <td>iris_test_packed</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>[u'character varying']</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" <td>15</td>\n",
" <td>1.0</td>\n",
" <td>[3]</td>\n",
" <td>all_segments</td>\n",
" <td>all_segments</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_test', u'iris_test_packed', [u'class_text'], [u'attributes'], [u'character varying'], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'], 15, 1.0, [3], 'all_segments', 'all_segments')]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;\n",
"\n",
"SELECT madlib.validation_preprocessor_dl('iris_test', -- Source table\n",
" 'iris_test_packed', -- Output table\n",
" 'class_text', -- Dependent variable\n",
" 'attributes', -- Independent variable\n",
" 'iris_train_packed' -- From training preprocessor step\n",
" ); \n",
"\n",
"SELECT * FROM iris_test_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load\"></a>\n",
"# 3. Define and load model architecture\n",
"Import Keras libraries"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow import keras\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define model architecture"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense (Dense) (None, 10) 50 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 3) 33 \n",
"=================================================================\n",
"Total params: 193\n",
"Trainable params: 193\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model_simple = Sequential()\n",
"model_simple.add(Dense(10, activation='relu', input_shape=(4,)))\n",
"model_simple.add(Dense(10, activation='relu'))\n",
"model_simple.add(Dense(3, activation='softmax'))\n",
" \n",
"model_simple.summary()"
]
},
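{
"cell_type": "markdown",
"metadata": {},
"source": [
"A `Dense` layer has one weight for each input-output pair plus one bias for each output unit. That gives $(n_{\\text{in}} + 1) \\times n_{\\text{out}}$ parameters per layer, which reproduces the `Param #` column of the summary:\n",
"\n",
"* `dense`: $(4 + 1) \\times 10 = 50$\n",
"* `dense_1`: $(10 + 1) \\times 10 = 110$\n",
"* `dense_2`: $(10 + 1) \\times 3 = 33$\n",
"\n",
"The total is $50 + 110 + 33 = 193$ trainable parameters."
]
},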
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.2.4-tf\", \"config\": {\"layers\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"name\": \"sequential\"}, \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_simple.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
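"Load the architecture into a model architecture table. The JSON below was exported from an older Keras version (2.1.6), so it differs cosmetically from the `to_json()` output above. The layer names differ, and it uses `VarianceScaling` with `fan_avg`/`uniform` instead of the equivalent `GlorotUniform`. It describes the same network: 4 inputs, two 10-unit ReLU layers and a 3-unit softmax output."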
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>model_arch</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_1', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>Sophie</td>\n",
" <td>A simple model</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u' ... (1340 characters truncated) ... s_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, u'Sophie', u'A simple model')]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS model_arch_library;\n",
"SELECT madlib.load_keras_model('model_arch_library', -- Output table\n",
" \n",
"$$\n",
"{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}\n",
"$$\n",
"::json, -- JSON blob\n",
" NULL, -- Weights\n",
" 'Sophie', -- Name\n",
" 'A simple model' -- Description\n",
");\n",
"\n",
"SELECT model_id, model_arch, name, description FROM model_arch_library ORDER BY model_id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"train\"></a>\n",
"# 4. Train\n",
"Train the model:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_model, iris_model_summary;\n",
"\n",
"SELECT madlib.madlib_keras_fit('iris_train_packed', -- source table\n",
" 'iris_model', -- model output table\n",
" 'model_arch_library', -- model arch table\n",
" 1, -- model arch id\n",
" $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$, -- compile_params\n",
" $$ batch_size=5, epochs=3 $$, -- fit_params\n",
" 10 -- num_iterations\n",
" );"
]
},
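{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `madlib_keras_fit`, `num_iterations` is separate from the `epochs` value passed in `fit_params`. As we understand the MADlib user documentation, each iteration runs Keras `fit` for the given number of epochs on each segment's share of the data, and then the segment models are averaged. Under that reading, each segment makes\n",
"\n",
"$$10 \\text{ iterations} \\times 3 \\text{ epochs} = 30 \\text{ passes}$$\n",
"\n",
"over its data, with `batch_size=5` rows per gradient update."
]
},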
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the model summary:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>model</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>num_iterations</th>\n",
" <th>validation_table</th>\n",
" <th>object_table</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" <th>metrics_iters</th>\n",
" <th>class_text_class_values</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_model</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>1</td>\n",
" <td> loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] </td>\n",
" <td> batch_size=5, epochs=3 </td>\n",
" <td>10</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>10</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>2021-03-06 00:27:28.144705</td>\n",
" <td>2021-03-06 00:27:31.754147</td>\n",
" <td>[3.60936093330383]</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[3]</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.916666686535</td>\n",
" <td>0.463008254766</td>\n",
" <td>[0.916666686534882]</td>\n",
" <td>[0.463008254766464]</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>[10]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_model', [u'class_text'], [u'attributes'], u'model_arch_library', 1, u\" loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] \", u' batch_size=5, epochs=3 ', 10, None, None, 10, None, None, u'madlib_keras', 0.7900390625, datetime.datetime(2021, 3, 6, 0, 27, 28, 144705), datetime.datetime(2021, 3, 6, 0, 27, 31, 754147), [3.60936093330383], u'1.18.0-dev', [3], [u'character varying'], 1.0, [u'accuracy'], u'categorical_crossentropy', 0.916666686534882, 0.463008254766464, [0.916666686534882], [0.463008254766464], None, None, None, None, [10], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'])]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"eval\"></a>\n",
"# 5. Evaluate\n",
"\n",
"Now evaluate the model built above against the packed test table:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>loss</th>\n",
" <th>metric</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.523572981358</td>\n",
" <td>0.933333337307</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.523572981357574, 0.933333337306976, [u'accuracy'], u'categorical_crossentropy')]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_validate;\n",
"\n",
"SELECT madlib.madlib_keras_evaluate('iris_model', -- model\n",
" 'iris_test_packed', -- test table\n",
" 'iris_validate' -- output table\n",
" );\n",
"\n",
"SELECT * FROM iris_validate;"
]
},
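{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `loss` column above is the categorical cross-entropy requested in `compile_params`, averaged over the test rows. The following is a minimal plain-Python sketch of what that number measures. It is not MADlib's or Keras's actual code, and the helper name `categorical_crossentropy` is our own. Each row with one-hot label y and predicted probabilities p contributes -sum(y_i * log(p_i)). The example assumes the true label of id 10 is Iris-setosa and uses the probabilities predicted for it in the prediction output in the next section:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def categorical_crossentropy(y_true, y_pred, eps=1e-7):\n",
"    # Mean over rows of -sum_i y_i * log(p_i).\n",
"    # Each probability is clipped to [eps, 1 - eps] so math.log never sees 0.\n",
"    total = 0.0\n",
"    for labels, probs in zip(y_true, y_pred):\n",
"        row_loss = 0.0\n",
"        for y, p in zip(labels, probs):\n",
"            p = min(max(p, eps), 1.0 - eps)\n",
"            row_loss -= y * math.log(p)\n",
"        total += row_loss\n",
"    return total / len(y_true)\n",
"\n",
"\n",
"# One-hot label for Iris-setosa in the order (setosa, versicolor, virginica),\n",
"# paired with the probabilities predicted for id 10.\n",
"loss = categorical_crossentropy([[1, 0, 0]], [[0.83670896, 0.14060013, 0.022690918]])\n",
"print(round(loss, 4))\n",
"```\n",
"\n",
"A confident, correct prediction like this one adds little to the average. By contrast, a row such as id 56, where no class gets more than about 0.47, adds roughly 0.76 or more, whichever class is correct."
]
},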
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred\"></a>\n",
"# 6. Predict\n",
"\n",
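"Now predict using the model we built. We use the same test data that we evaluated on above. This is not typical practice, but it serves to show the syntax. Each test row gets one output row per class. The class_value column holds the class label, prob holds its predicted probability, and rank orders the classes from most to least likely (rank 1 is the predicted class):"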
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"90 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>class_name</th>\n",
" <th>class_value</th>\n",
" <th>prob</th>\n",
" <th>rank</th>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.83670896</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.14060013</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.022690918</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.8369735</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.14013577</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.022890732</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.87973696</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.10638312</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.013879963</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.93740743</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.056862056</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.0057305074</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.83670896</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.14060013</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.022690918</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.8709096</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.11054307</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.018547323</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.4681935</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4571225</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.07468399</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.45466852</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.4470526</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.09827888</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.47486252</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.46100235</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.064135045</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.47181308</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.43595785</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.09222904</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.47956672</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.41212082</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.10831244</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.50861007</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.3626588</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.12873109</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.5061021</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.3914343</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.102463536</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.49345753</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.35755217</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.14899038</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4796765</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.4385325</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0817909</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.47809058</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.34930265</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.17260681</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6172143</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.3620455</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.020740215</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5837618</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.3847274</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.03151086</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.61951214</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.3637118</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.016776035</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5954762</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.37995332</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.024570476</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.571379</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4039808</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.024640195</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.57040656</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.3980587</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.03153468</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.52341586</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.43971062</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.036873452</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5800313</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.3929817</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.026986998</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.72622484</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.26773784</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0060372944</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5089497</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.44541556</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.045634773</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.62922823</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.35819018</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.012581516</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6017383</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.3781529</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.020108894</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5293082</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4390557</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.031636048</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.58249867</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.39045528</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.027046034</td>\n",
" <td>3</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(10, u'class_text', u'Iris-setosa', 0.83670896, 1),\n",
" (10, u'class_text', u'Iris-versicolor', 0.14060013, 2),\n",
" (10, u'class_text', u'Iris-virginica', 0.022690918, 3),\n",
" (13, u'class_text', u'Iris-setosa', 0.8369735, 1),\n",
" (13, u'class_text', u'Iris-versicolor', 0.14013577, 2),\n",
" (13, u'class_text', u'Iris-virginica', 0.022890732, 3),\n",
" (29, u'class_text', u'Iris-setosa', 0.87973696, 1),\n",
" (29, u'class_text', u'Iris-versicolor', 0.10638312, 2),\n",
" (29, u'class_text', u'Iris-virginica', 0.013879963, 3),\n",
" (34, u'class_text', u'Iris-setosa', 0.93740743, 1),\n",
" (34, u'class_text', u'Iris-versicolor', 0.056862056, 2),\n",
" (34, u'class_text', u'Iris-virginica', 0.0057305074, 3),\n",
" (38, u'class_text', u'Iris-setosa', 0.83670896, 1),\n",
" (38, u'class_text', u'Iris-versicolor', 0.14060013, 2),\n",
" (38, u'class_text', u'Iris-virginica', 0.022690918, 3),\n",
" (43, u'class_text', u'Iris-setosa', 0.8709096, 1),\n",
" (43, u'class_text', u'Iris-versicolor', 0.11054307, 2),\n",
" (43, u'class_text', u'Iris-virginica', 0.018547323, 3),\n",
" (56, u'class_text', u'Iris-virginica', 0.4681935, 1),\n",
" (56, u'class_text', u'Iris-versicolor', 0.4571225, 2),\n",
" (56, u'class_text', u'Iris-setosa', 0.07468399, 3),\n",
" (61, u'class_text', u'Iris-versicolor', 0.45466852, 1),\n",
" (61, u'class_text', u'Iris-virginica', 0.4470526, 2),\n",
" (61, u'class_text', u'Iris-setosa', 0.09827888, 3),\n",
" (64, u'class_text', u'Iris-virginica', 0.47486252, 1),\n",
" (64, u'class_text', u'Iris-versicolor', 0.46100235, 2),\n",
" (64, u'class_text', u'Iris-setosa', 0.064135045, 3),\n",
" (67, u'class_text', u'Iris-versicolor', 0.47181308, 1),\n",
" (67, u'class_text', u'Iris-virginica', 0.43595785, 2),\n",
" (67, u'class_text', u'Iris-setosa', 0.09222904, 3),\n",
" (70, u'class_text', u'Iris-versicolor', 0.47956672, 1),\n",
" (70, u'class_text', u'Iris-virginica', 0.41212082, 2),\n",
" (70, u'class_text', u'Iris-setosa', 0.10831244, 3),\n",
" (72, u'class_text', u'Iris-versicolor', 0.50861007, 1),\n",
" (72, u'class_text', u'Iris-virginica', 0.3626588, 2),\n",
" (72, u'class_text', u'Iris-setosa', 0.12873109, 3),\n",
" (75, u'class_text', u'Iris-versicolor', 0.5061021, 1),\n",
" (75, u'class_text', u'Iris-virginica', 0.3914343, 2),\n",
" (75, u'class_text', u'Iris-setosa', 0.102463536, 3),\n",
" (89, u'class_text', u'Iris-versicolor', 0.49345753, 1),\n",
" (89, u'class_text', u'Iris-virginica', 0.35755217, 2),\n",
" (89, u'class_text', u'Iris-setosa', 0.14899038, 3),\n",
" (92, u'class_text', u'Iris-versicolor', 0.4796765, 1),\n",
" (92, u'class_text', u'Iris-virginica', 0.4385325, 2),\n",
" (92, u'class_text', u'Iris-setosa', 0.0817909, 3),\n",
" (94, u'class_text', u'Iris-versicolor', 0.47809058, 1),\n",
" (94, u'class_text', u'Iris-virginica', 0.34930265, 2),\n",
" (94, u'class_text', u'Iris-setosa', 0.17260681, 3),\n",
" (101, u'class_text', u'Iris-virginica', 0.6172143, 1),\n",
" (101, u'class_text', u'Iris-versicolor', 0.3620455, 2),\n",
" (101, u'class_text', u'Iris-setosa', 0.020740215, 3),\n",
" (102, u'class_text', u'Iris-virginica', 0.5837618, 1),\n",
" (102, u'class_text', u'Iris-versicolor', 0.3847274, 2),\n",
" (102, u'class_text', u'Iris-setosa', 0.03151086, 3),\n",
" (103, u'class_text', u'Iris-virginica', 0.61951214, 1),\n",
" (103, u'class_text', u'Iris-versicolor', 0.3637118, 2),\n",
" (103, u'class_text', u'Iris-setosa', 0.016776035, 3),\n",
" (112, u'class_text', u'Iris-virginica', 0.5954762, 1),\n",
" (112, u'class_text', u'Iris-versicolor', 0.37995332, 2),\n",
" (112, u'class_text', u'Iris-setosa', 0.024570476, 3),\n",
" (113, u'class_text', u'Iris-virginica', 0.571379, 1),\n",
" (113, u'class_text', u'Iris-versicolor', 0.4039808, 2),\n",
" (113, u'class_text', u'Iris-setosa', 0.024640195, 3),\n",
" (115, u'class_text', u'Iris-virginica', 0.57040656, 1),\n",
" (115, u'class_text', u'Iris-versicolor', 0.3980587, 2),\n",
" (115, u'class_text', u'Iris-setosa', 0.03153468, 3),\n",
" (116, u'class_text', u'Iris-virginica', 0.52341586, 1),\n",
" (116, u'class_text', u'Iris-versicolor', 0.43971062, 2),\n",
" (116, u'class_text', u'Iris-setosa', 0.036873452, 3),\n",
" (117, u'class_text', u'Iris-virginica', 0.5800313, 1),\n",
" (117, u'class_text', u'Iris-versicolor', 0.3929817, 2),\n",
" (117, u'class_text', u'Iris-setosa', 0.026986998, 3),\n",
" (119, u'class_text', u'Iris-virginica', 0.72622484, 1),\n",
" (119, u'class_text', u'Iris-versicolor', 0.26773784, 2),\n",
" (119, u'class_text', u'Iris-setosa', 0.0060372944, 3),\n",
" (127, u'class_text', u'Iris-virginica', 0.5089497, 1),\n",
" (127, u'class_text', u'Iris-versicolor', 0.44541556, 2),\n",
" (127, u'class_text', u'Iris-setosa', 0.045634773, 3),\n",
" (136, u'class_text', u'Iris-virginica', 0.62922823, 1),\n",
" (136, u'class_text', u'Iris-versicolor', 0.35819018, 2),\n",
" (136, u'class_text', u'Iris-setosa', 0.012581516, 3),\n",
" (144, u'class_text', u'Iris-virginica', 0.6017383, 1),\n",
" (144, u'class_text', u'Iris-versicolor', 0.3781529, 2),\n",
" (144, u'class_text', u'Iris-setosa', 0.020108894, 3),\n",
" (146, u'class_text', u'Iris-virginica', 0.5293082, 1),\n",
" (146, u'class_text', u'Iris-versicolor', 0.4390557, 2),\n",
" (146, u'class_text', u'Iris-setosa', 0.031636048, 3),\n",
" (147, u'class_text', u'Iris-virginica', 0.58249867, 1),\n",
" (147, u'class_text', u'Iris-versicolor', 0.39045528, 2),\n",
" (147, u'class_text', u'Iris-setosa', 0.027046034, 3)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict;\n",
"\n",
"SELECT madlib.madlib_keras_predict('iris_model', -- model\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict' -- output table\n",
" );\n",
"\n",
"SELECT * FROM iris_predict ORDER BY id, rank;"
]
},
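{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each of the 30 test ids appears three times in the output above, once per class, which is why 90 rows are returned. The following is a minimal plain-Python sketch of how one probability vector becomes those ranked rows. `rank_predictions` is a hypothetical helper, not a MADlib function. The sketch assumes the probabilities are listed in the one-hot class order Iris-setosa, Iris-versicolor, Iris-virginica:\n",
"\n",
"```python\n",
"CLASS_VALUES = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']\n",
"\n",
"\n",
"def rank_predictions(row_id, probs, class_values=CLASS_VALUES):\n",
"    # Pair each probability with its class label by position, then sort\n",
"    # from most to least likely and number the results starting at rank 1.\n",
"    pairs = sorted(zip(class_values, probs), key=lambda cp: cp[1], reverse=True)\n",
"    return [(row_id, label, prob, rank)\n",
"            for rank, (label, prob) in enumerate(pairs, start=1)]\n",
"\n",
"\n",
"# Probabilities for id 56 in class order, taken from the output above.\n",
"for row in rank_predictions(56, [0.07468399, 0.4571225, 0.4681935]):\n",
"    print(row)\n",
"```"
]
},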
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count misclassifications, considering only the top-ranked (rank = 1) prediction for each row:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2L,)]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM iris_predict JOIN iris_test USING (id)\n",
"WHERE iris_predict.class_value != iris_test.class_text AND iris_predict.rank = 1;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
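"Compute the test accuracy as a percentage. The denominator 150*0.2 = 30 is the number of rows in the test set (20% of the 150-row iris dataset):"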
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>test_accuracy_percent</th>\n",
" </tr>\n",
" <tr>\n",
" <td>93.33</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('93.33'),)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT round(count(*)*100/(150*0.2),2) as test_accuracy_percent from\n",
" (select iris_test.class_text as actual, iris_predict.class_value as estimated\n",
" from iris_predict inner join iris_test\n",
" on iris_test.id=iris_predict.id where iris_predict.rank = 1) q\n",
"WHERE q.actual=q.estimated;"
]
},
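{
"cell_type": "markdown",
"metadata": {},
"source": [
"The query above hard-codes the test-set size as 150*0.2. The following is a minimal plain-Python sketch of the same top-1 accuracy calculation that takes the denominator from the data instead. `top1_accuracy_percent` and the toy labels are hypothetical, chosen to mirror 2 misclassifications out of 30 rows:\n",
"\n",
"```python\n",
"def top1_accuracy_percent(actual, predicted):\n",
"    # actual and predicted map row id -> class label (rank-1 prediction only).\n",
"    # The denominator is the number of test rows, not a hard-coded constant.\n",
"    correct = sum(1 for row_id, label in actual.items()\n",
"                  if predicted.get(row_id) == label)\n",
"    return round(100.0 * correct / len(actual), 2)\n",
"\n",
"\n",
"# Toy data: 30 test rows, 2 of them misclassified, so 28/30 are correct.\n",
"actual = {i: 'a' for i in range(30)}\n",
"predicted = dict(actual)\n",
"predicted[0] = 'b'\n",
"predicted[1] = 'b'\n",
"print(top1_accuracy_percent(actual, predicted))\n",
"```"
]
},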
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred_byom\"></a>\n",
"# 7. Predict BYOM\n",
"The predict BYOM (bring your own model) function lets you run inference with models that were not trained in MADlib but were imported from elsewhere.\n",
"\n",
"As in the previous section, we predict on the test dataset. This is not typical practice, but it serves to show the syntax.\n",
"\n",
"See load_keras_model() at\n",
"http://madlib.apache.org/docs/latest/group__grp__keras__model__arch.html\n",
"for details on how to load a model architecture and its weights. In this example we reuse the weights already trained above, copying them from iris_model into the model architecture table:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"UPDATE model_arch_library \n",
"SET model_weights = iris_model.model_weights \n",
"FROM iris_model \n",
"WHERE model_arch_library.model_id = 1;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now predict using the model stored in the model architecture table directly, without referencing the model table produced by MADlib training.\n",
"\n",
"Note that if you specify the class values parameter, as we do below, it must reflect how the dependent variable was one-hot encoded for training. In this example, training_preprocessor_dl() in Step 2 above encoded the classes in the order {'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}, so this is the order we pass in the parameter. If we accidentally picked an order that did not match the one-hot encoding, the predictions would be wrong."
]
},
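{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is why the class order matters. The model itself only outputs a probability vector, and the class values array maps each position to a label. The following is a minimal plain-Python sketch using the probabilities predicted for id 10 earlier. `label_for` is a hypothetical helper, not part of MADlib:\n",
"\n",
"```python\n",
"def label_for(probs, class_values):\n",
"    # The predicted label is the class at the position of the largest probability.\n",
"    best = max(range(len(probs)), key=lambda i: probs[i])\n",
"    return class_values[best]\n",
"\n",
"\n",
"probs_id_10 = [0.83670896, 0.14060013, 0.022690918]\n",
"\n",
"# Order used by training_preprocessor_dl() for the one-hot encoding.\n",
"correct_order = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']\n",
"# A mismatched order silently relabels the same probabilities.\n",
"wrong_order = ['Iris-virginica', 'Iris-setosa', 'Iris-versicolor']\n",
"\n",
"print(label_for(probs_id_10, correct_order))\n",
"print(label_for(probs_id_10, wrong_order))\n",
"```\n",
"\n",
"With the mismatched order, the most probable position (index 0) is reported as Iris-virginica. The model's output has not changed; only the labels attached to it have."
]
},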
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"30 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>class_name</th>\n",
" <th>class_value</th>\n",
" <th>prob</th>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.83670896</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.8369735</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.87973696</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.93740743</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.83670896</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.8709096</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.4681935</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.45466852</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.47486252</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.47181308</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.47956672</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.50861007</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.5061021</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.49345753</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.4796765</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.47809058</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6172143</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5837618</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.61951214</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5954762</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.571379</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.57040656</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.52341586</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5800313</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.72622484</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5089497</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.62922823</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6017383</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5293082</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>dependent_var</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.58249867</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(10, u'dependent_var', u'Iris-setosa', 0.83670896),\n",
" (13, u'dependent_var', u'Iris-setosa', 0.8369735),\n",
" (29, u'dependent_var', u'Iris-setosa', 0.87973696),\n",
" (34, u'dependent_var', u'Iris-setosa', 0.93740743),\n",
" (38, u'dependent_var', u'Iris-setosa', 0.83670896),\n",
" (43, u'dependent_var', u'Iris-setosa', 0.8709096),\n",
" (56, u'dependent_var', u'Iris-virginica', 0.4681935),\n",
" (61, u'dependent_var', u'Iris-versicolor', 0.45466852),\n",
" (64, u'dependent_var', u'Iris-virginica', 0.47486252),\n",
" (67, u'dependent_var', u'Iris-versicolor', 0.47181308),\n",
" (70, u'dependent_var', u'Iris-versicolor', 0.47956672),\n",
" (72, u'dependent_var', u'Iris-versicolor', 0.50861007),\n",
" (75, u'dependent_var', u'Iris-versicolor', 0.5061021),\n",
" (89, u'dependent_var', u'Iris-versicolor', 0.49345753),\n",
" (92, u'dependent_var', u'Iris-versicolor', 0.4796765),\n",
" (94, u'dependent_var', u'Iris-versicolor', 0.47809058),\n",
" (101, u'dependent_var', u'Iris-virginica', 0.6172143),\n",
" (102, u'dependent_var', u'Iris-virginica', 0.5837618),\n",
" (103, u'dependent_var', u'Iris-virginica', 0.61951214),\n",
" (112, u'dependent_var', u'Iris-virginica', 0.5954762),\n",
" (113, u'dependent_var', u'Iris-virginica', 0.571379),\n",
" (115, u'dependent_var', u'Iris-virginica', 0.57040656),\n",
" (116, u'dependent_var', u'Iris-virginica', 0.52341586),\n",
" (117, u'dependent_var', u'Iris-virginica', 0.5800313),\n",
" (119, u'dependent_var', u'Iris-virginica', 0.72622484),\n",
" (127, u'dependent_var', u'Iris-virginica', 0.5089497),\n",
" (136, u'dependent_var', u'Iris-virginica', 0.62922823),\n",
" (144, u'dependent_var', u'Iris-virginica', 0.6017383),\n",
" (146, u'dependent_var', u'Iris-virginica', 0.5293082),\n",
" (147, u'dependent_var', u'Iris-virginica', 0.58249867)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict_byom;\n",
"\n",
"SELECT madlib.madlib_keras_predict_byom('model_arch_library', -- model arch table\n",
" 1, -- model arch id\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict_byom', -- output table\n",
" 'response', -- prediction type\n",
" FALSE, -- use GPUs\n",
" ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']], -- class values\n",
" 1.0 -- normalizing const\n",
" );\n",
"\n",
"SELECT * FROM iris_predict_byom ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count misclassifications:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2L,)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM iris_predict_byom JOIN iris_test USING (id)\n",
"WHERE iris_predict_byom.class_value != iris_test.class_text;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test accuracy (percentage of test rows classified correctly):"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>test_accuracy_percent</th>\n",
" </tr>\n",
" <tr>\n",
" <td>93.33</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('93.33'),)]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT round(count(*)*100.0/(SELECT count(*) FROM iris_test), 2) AS test_accuracy_percent FROM\n",
" (SELECT iris_test.class_text AS actual, iris_predict_byom.class_value AS estimated\n",
" FROM iris_predict_byom INNER JOIN iris_test\n",
" ON iris_test.id=iris_predict_byom.id) q\n",
"WHERE q.actual=q.estimated;"
]
},
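{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a plain-Python cross-check of the two queries above, the sketch below recomputes the accuracy from the reported counts: 2 misclassified rows out of 30 test rows (20% of the 150 iris rows). The helper `accuracy_percent` and the label lists are hypothetical and for illustration only. They do not query the database."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: percent of rows whose estimated label matches the actual label\n",
"def accuracy_percent(actual, estimated):\n",
"    matches = sum(1 for a, e in zip(actual, estimated) if a == e)\n",
"    return round(matches * 100.0 / len(actual), 2)\n",
"\n",
"# Illustrative labels only: 30 test rows, 2 of them misclassified\n",
"actual = ['Iris-versicolor'] * 30\n",
"estimated = ['Iris-versicolor'] * 28 + ['Iris-virginica'] * 2\n",
"\n",
"print(accuracy_percent(actual, estimated))  # 93.33"
]
},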
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"class2\"></a>\n",
"# Classification with Other Parameters\n",
"\n",
"<a id=\"val_dataset\"></a>\n",
"# 1. Validation dataset\n",
"Now use a validation dataset and compute metrics every 2nd iteration using the 'metrics_compute_frequency' parameter. This can help reduce run time if you do not need metrics computed at every iteration."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_model, iris_model_summary;\n",
"\n",
"SELECT madlib.madlib_keras_fit('iris_train_packed', -- source table\n",
" 'iris_model', -- model output table\n",
" 'model_arch_library', -- model arch table\n",
" 1, -- model arch id\n",
" $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$, -- compile_params\n",
" $$ batch_size=5, epochs=3 $$, -- fit_params\n",
" 10, -- num_iterations\n",
" FALSE, -- use GPUs\n",
" 'iris_test_packed', -- validation dataset\n",
" 2, -- metrics compute frequency\n",
" FALSE, -- warm start\n",
" 'Sophie L.', -- name\n",
" 'Simple MLP for iris dataset' -- description\n",
" );"
]
},
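{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `metrics_iters` column in the summary table records the iterations at which metrics were computed. The minimal sketch below shows how those iterations follow from `num_iterations` and `metrics_compute_frequency`. It assumes metrics are computed every `metrics_compute_frequency` iterations and also after the final iteration. The helper `metrics_iterations` is hypothetical and is not a MADlib function. For the call above (10 iterations, frequency 2) it gives `[2, 4, 6, 8, 10]`, which matches the summary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (not part of MADlib): iterations at which metrics are computed,\n",
"# assuming every `freq` iterations plus the final iteration\n",
"def metrics_iterations(num_iterations, freq):\n",
"    iters = list(range(freq, num_iterations + 1, freq))\n",
"    if not iters or iters[-1] != num_iterations:\n",
"        iters.append(num_iterations)\n",
"    return iters\n",
"\n",
"print(metrics_iterations(10, 2))  # [2, 4, 6, 8, 10]\n",
"print(metrics_iterations(10, 3))  # [3, 6, 9, 10]"
]
},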
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the model summary:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>model</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>num_iterations</th>\n",
" <th>validation_table</th>\n",
" <th>object_table</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" <th>metrics_iters</th>\n",
" <th>class_text_class_values</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_model</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>1</td>\n",
" <td> loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] </td>\n",
" <td> batch_size=5, epochs=3 </td>\n",
" <td>10</td>\n",
" <td>iris_test_packed</td>\n",
" <td>None</td>\n",
" <td>2</td>\n",
" <td>Sophie L.</td>\n",
" <td>Simple MLP for iris dataset</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>2021-03-06 00:27:42.910502</td>\n",
" <td>2021-03-06 00:27:44.171209</td>\n",
" <td>[0.706467866897583, 0.850914001464844, 0.988704919815063, 1.12321996688843, 1.26061987876892]</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[3]</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.966666638851</td>\n",
" <td>0.286152273417</td>\n",
" <td>[0.933333337306976, 0.941666662693024, 0.941666662693024, 0.958333313465118, 0.966666638851166]</td>\n",
" <td>[0.410510897636414, 0.371806919574738, 0.339208543300629, 0.310610443353653, 0.286152273416519]</td>\n",
" <td>1.0</td>\n",
" <td>0.312809795141</td>\n",
" <td>[1.0, 1.0, 1.0, 1.0, 1.0]</td>\n",
" <td>[0.478174388408661, 0.426770567893982, 0.391106754541397, 0.351149171590805, 0.31280979514122]</td>\n",
" <td>[2, 4, 6, 8, 10]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_model', [u'class_text'], [u'attributes'], u'model_arch_library', 1, u\" loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] \", u' batch_size=5, epochs=3 ', 10, u'iris_test_packed', None, 2, u'Sophie L.', u'Simple MLP for iris dataset', u'madlib_keras', 0.7900390625, datetime.datetime(2021, 3, 6, 0, 27, 42, 910502), datetime.datetime(2021, 3, 6, 0, 27, 44, 171209), [0.706467866897583, 0.850914001464844, 0.988704919815063, 1.12321996688843, 1.26061987876892], u'1.18.0-dev', [3], [u'character varying'], 1.0, [u'accuracy'], u'categorical_crossentropy', 0.966666638851166, 0.286152273416519, [0.933333337306976, 0.941666662693024, 0.941666662693024, 0.958333313465118, 0.966666638851166], [0.410510897636414, 0.371806919574738, 0.339208543300629, 0.310610443353653, 0.286152273416519], 1.0, 0.31280979514122, [1.0, 1.0, 1.0, 1.0, 1.0], [0.478174388408661, 0.426770567893982, 0.391106754541397, 0.351149171590805, 0.31280979514122], [2, 4, 6, 8, 10], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'])]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accuracy by iteration"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import sys\n",
"import os\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# get accuracy and iteration number\n",
"iters_proxy = %sql SELECT metrics_iters FROM iris_model_summary;\n",
"train_accuracy_proxy = %sql SELECT training_metrics FROM iris_model_summary;\n",
"test_accuracy_proxy = %sql SELECT validation_metrics FROM iris_model_summary;\n",
"\n",
"# get number of points\n",
"num_points_proxy = %sql SELECT array_length(metrics_iters,1) FROM iris_model_summary;\n",
"num_points = num_points_proxy[0]\n",
"\n",
"# reshape to np arrays\n",
"iters = np.array(iters_proxy).reshape(num_points)\n",
"train_accuracy = np.array(train_accuracy_proxy).reshape(num_points)\n",
"test_accuracy = np.array(test_accuracy_proxy).reshape(num_points)\n",
"\n",
"#plot\n",
"plt.title('Iris validation accuracy by iteration')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Accuracy')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_accuracy, 'g.-', label='Train')\n",
"plt.plot(iters, test_accuracy, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loss by iteration"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# get loss\n",
"train_loss_proxy = %sql SELECT training_loss FROM iris_model_summary;\n",
"test_loss_proxy = %sql SELECT validation_loss FROM iris_model_summary;\n",
"\n",
"# reshape to np arrays\n",
"train_loss = np.array(train_loss_proxy).reshape(num_points)\n",
"test_loss = np.array(test_loss_proxy).reshape(num_points)\n",
"\n",
"#plot\n",
"plt.title('Iris validation loss by iteration')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Loss')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_loss, 'g.-', label='Train')\n",
"plt.plot(iters, test_loss, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accuracy by time"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# get time\n",
"time_proxy = %sql SELECT metrics_elapsed_time FROM iris_model_summary;\n",
"\n",
"# reshape to np arrays\n",
"time = np.array(time_proxy).reshape(num_points)/60.0\n",
"\n",
"#plot\n",
"plt.title('Iris validation accuracy by time')\n",
"plt.xlabel('Time (min)')\n",
"plt.ylabel('Accuracy')\n",
"plt.grid(True)\n",
"plt.plot(time, train_accuracy, 'g.-', label='Train')\n",
"plt.plot(time, test_accuracy, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to achieve a given accuracy"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"#plot\n",
"plt.title('Iris time by validation accuracy')\n",
"plt.xlabel('Accuracy')\n",
"plt.ylabel('Time (min)')\n",
"plt.grid(True)\n",
"plt.plot(train_accuracy, time, 'g.-', label='Train')\n",
"plt.plot(test_accuracy, time, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"pred_prob\"></a>\n",
"# 2. Predict probabilities\n",
"Predict the probability of each class for every test row:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"90 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>class_name</th>\n",
" <th>class_value</th>\n",
" <th>prob</th>\n",
" <th>rank</th>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.95964456</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.040107667</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.00024777954</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9597473</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.039995104</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.00025748153</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9761629</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.02375229</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>8.479464e-05</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.99167526</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.008311711</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>1.30546e-05</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.95964456</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.040107667</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.00024777954</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.9731986</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.026633002</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.0001682964</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.5465453</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.44023818</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.01321647</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.58524394</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.38831925</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.026436739</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.54376984</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.4466499</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.009580206</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.57931274</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.4024986</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.018188644</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.6546029</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.31436205</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.031035094</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.7274642</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.23212723</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.04040851</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.70608836</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.2678585</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.026053142</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.7004706</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.24725808</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.052271266</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.60884434</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.37558457</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.015571073</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.6937521</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.22628777</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.07996013</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.8121722</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.18745965</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>101</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.00036821415</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.732954</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.26589376</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>102</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0011522635</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7756057</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.22408752</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>103</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.00030681648</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.73672587</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.26255292</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>112</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.00072121713</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6946963</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.30448684</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>113</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0008168273</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7407642</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.2582139</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0010218392</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.6248408</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.37318283</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>116</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0019763925</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.69742286</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.30152085</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>117</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0010562533</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.8924382</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.10752802</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>119</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>3.379417e-05</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.5494711</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.44651195</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>127</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.004016916</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.7880493</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.21177953</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>136</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0001712008</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.76935935</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.23023891</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>144</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.00040176214</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.62273574</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.37572634</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>146</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0015378947</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-virginica</td>\n",
" <td>0.71288556</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-versicolor</td>\n",
" <td>0.2861904</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>147</td>\n",
" <td>class_text</td>\n",
" <td>Iris-setosa</td>\n",
" <td>0.0009240419</td>\n",
" <td>3</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(10, u'class_text', u'Iris-setosa', 0.95964456, 1),\n",
" (10, u'class_text', u'Iris-versicolor', 0.040107667, 2),\n",
" (10, u'class_text', u'Iris-virginica', 0.00024777954, 3),\n",
" (13, u'class_text', u'Iris-setosa', 0.9597473, 1),\n",
" (13, u'class_text', u'Iris-versicolor', 0.039995104, 2),\n",
" (13, u'class_text', u'Iris-virginica', 0.00025748153, 3),\n",
" (29, u'class_text', u'Iris-setosa', 0.9761629, 1),\n",
" (29, u'class_text', u'Iris-versicolor', 0.02375229, 2),\n",
" (29, u'class_text', u'Iris-virginica', 8.479464e-05, 3),\n",
" (34, u'class_text', u'Iris-setosa', 0.99167526, 1),\n",
" (34, u'class_text', u'Iris-versicolor', 0.008311711, 2),\n",
" (34, u'class_text', u'Iris-virginica', 1.30546e-05, 3),\n",
" (38, u'class_text', u'Iris-setosa', 0.95964456, 1),\n",
" (38, u'class_text', u'Iris-versicolor', 0.040107667, 2),\n",
" (38, u'class_text', u'Iris-virginica', 0.00024777954, 3),\n",
" (43, u'class_text', u'Iris-setosa', 0.9731986, 1),\n",
" (43, u'class_text', u'Iris-versicolor', 0.026633002, 2),\n",
" (43, u'class_text', u'Iris-virginica', 0.0001682964, 3),\n",
" (56, u'class_text', u'Iris-versicolor', 0.5465453, 1),\n",
" (56, u'class_text', u'Iris-virginica', 0.44023818, 2),\n",
" (56, u'class_text', u'Iris-setosa', 0.01321647, 3),\n",
" (61, u'class_text', u'Iris-versicolor', 0.58524394, 1),\n",
" (61, u'class_text', u'Iris-virginica', 0.38831925, 2),\n",
" (61, u'class_text', u'Iris-setosa', 0.026436739, 3),\n",
" (64, u'class_text', u'Iris-versicolor', 0.54376984, 1),\n",
" (64, u'class_text', u'Iris-virginica', 0.4466499, 2),\n",
" (64, u'class_text', u'Iris-setosa', 0.009580206, 3),\n",
" (67, u'class_text', u'Iris-versicolor', 0.57931274, 1),\n",
" (67, u'class_text', u'Iris-virginica', 0.4024986, 2),\n",
" (67, u'class_text', u'Iris-setosa', 0.018188644, 3),\n",
" (70, u'class_text', u'Iris-versicolor', 0.6546029, 1),\n",
" (70, u'class_text', u'Iris-virginica', 0.31436205, 2),\n",
" (70, u'class_text', u'Iris-setosa', 0.031035094, 3),\n",
" (72, u'class_text', u'Iris-versicolor', 0.7274642, 1),\n",
" (72, u'class_text', u'Iris-virginica', 0.23212723, 2),\n",
" (72, u'class_text', u'Iris-setosa', 0.04040851, 3),\n",
" (75, u'class_text', u'Iris-versicolor', 0.70608836, 1),\n",
" (75, u'class_text', u'Iris-virginica', 0.2678585, 2),\n",
" (75, u'class_text', u'Iris-setosa', 0.026053142, 3),\n",
" (89, u'class_text', u'Iris-versicolor', 0.7004706, 1),\n",
" (89, u'class_text', u'Iris-virginica', 0.24725808, 2),\n",
" (89, u'class_text', u'Iris-setosa', 0.052271266, 3),\n",
" (92, u'class_text', u'Iris-versicolor', 0.60884434, 1),\n",
" (92, u'class_text', u'Iris-virginica', 0.37558457, 2),\n",
" (92, u'class_text', u'Iris-setosa', 0.015571073, 3),\n",
" (94, u'class_text', u'Iris-versicolor', 0.6937521, 1),\n",
" (94, u'class_text', u'Iris-virginica', 0.22628777, 2),\n",
" (94, u'class_text', u'Iris-setosa', 0.07996013, 3),\n",
" (101, u'class_text', u'Iris-virginica', 0.8121722, 1),\n",
" (101, u'class_text', u'Iris-versicolor', 0.18745965, 2),\n",
" (101, u'class_text', u'Iris-setosa', 0.00036821415, 3),\n",
" (102, u'class_text', u'Iris-virginica', 0.732954, 1),\n",
" (102, u'class_text', u'Iris-versicolor', 0.26589376, 2),\n",
" (102, u'class_text', u'Iris-setosa', 0.0011522635, 3),\n",
" (103, u'class_text', u'Iris-virginica', 0.7756057, 1),\n",
" (103, u'class_text', u'Iris-versicolor', 0.22408752, 2),\n",
" (103, u'class_text', u'Iris-setosa', 0.00030681648, 3),\n",
" (112, u'class_text', u'Iris-virginica', 0.73672587, 1),\n",
" (112, u'class_text', u'Iris-versicolor', 0.26255292, 2),\n",
" (112, u'class_text', u'Iris-setosa', 0.00072121713, 3),\n",
" (113, u'class_text', u'Iris-virginica', 0.6946963, 1),\n",
" (113, u'class_text', u'Iris-versicolor', 0.30448684, 2),\n",
" (113, u'class_text', u'Iris-setosa', 0.0008168273, 3),\n",
" (115, u'class_text', u'Iris-virginica', 0.7407642, 1),\n",
" (115, u'class_text', u'Iris-versicolor', 0.2582139, 2),\n",
" (115, u'class_text', u'Iris-setosa', 0.0010218392, 3),\n",
" (116, u'class_text', u'Iris-virginica', 0.6248408, 1),\n",
" (116, u'class_text', u'Iris-versicolor', 0.37318283, 2),\n",
" (116, u'class_text', u'Iris-setosa', 0.0019763925, 3),\n",
" (117, u'class_text', u'Iris-virginica', 0.69742286, 1),\n",
" (117, u'class_text', u'Iris-versicolor', 0.30152085, 2),\n",
" (117, u'class_text', u'Iris-setosa', 0.0010562533, 3),\n",
" (119, u'class_text', u'Iris-virginica', 0.8924382, 1),\n",
" (119, u'class_text', u'Iris-versicolor', 0.10752802, 2),\n",
" (119, u'class_text', u'Iris-setosa', 3.379417e-05, 3),\n",
" (127, u'class_text', u'Iris-virginica', 0.5494711, 1),\n",
" (127, u'class_text', u'Iris-versicolor', 0.44651195, 2),\n",
" (127, u'class_text', u'Iris-setosa', 0.004016916, 3),\n",
" (136, u'class_text', u'Iris-virginica', 0.7880493, 1),\n",
" (136, u'class_text', u'Iris-versicolor', 0.21177953, 2),\n",
" (136, u'class_text', u'Iris-setosa', 0.0001712008, 3),\n",
" (144, u'class_text', u'Iris-virginica', 0.76935935, 1),\n",
" (144, u'class_text', u'Iris-versicolor', 0.23023891, 2),\n",
" (144, u'class_text', u'Iris-setosa', 0.00040176214, 3),\n",
" (146, u'class_text', u'Iris-virginica', 0.62273574, 1),\n",
" (146, u'class_text', u'Iris-versicolor', 0.37572634, 2),\n",
" (146, u'class_text', u'Iris-setosa', 0.0015378947, 3),\n",
" (147, u'class_text', u'Iris-virginica', 0.71288556, 1),\n",
" (147, u'class_text', u'Iris-versicolor', 0.2861904, 2),\n",
" (147, u'class_text', u'Iris-setosa', 0.0009240419, 3)]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_predict;\n",
"\n",
"SELECT madlib.madlib_keras_predict('iris_model', -- model\n",
" 'iris_test', -- test_table\n",
" 'id', -- id column\n",
" 'attributes', -- independent var\n",
" 'iris_predict', -- output table\n",
" 'prob' -- response type\n",
" );\n",
"\n",
"SELECT * FROM iris_predict ORDER BY id, rank;"
]
},
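{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `prob` output lists every class for each test row, together with a `rank` column. The minimal sketch below shows how the ranks relate to the probabilities, using the three values reported above for id 10. Sorting in descending order yields ranks 1 to 3, and the rank-1 class is the one an argmax-based `response` prediction would return. It also checks that the values sum to approximately 1. That is expected if the model's final layer is a softmax, which is an assumption here because the architecture is defined earlier in the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Class probabilities reported above for test row id 10\n",
"probs = {\n",
"    'Iris-setosa': 0.95964456,\n",
"    'Iris-versicolor': 0.040107667,\n",
"    'Iris-virginica': 0.00024777954,\n",
"}\n",
"\n",
"# Rank classes by descending probability (rank 1 = most likely class)\n",
"ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)\n",
"for rank, (class_value, prob) in enumerate(ranked, start=1):\n",
"    print(rank, class_value, prob)\n",
"\n",
"# If the final layer is a softmax, the probabilities sum to approximately 1\n",
"print(abs(sum(probs.values()) - 1.0) < 1e-6)  # True"
]
},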
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"warm_start\"></a>\n",
"# 3. Warm start\n",
"Next, use the warm_start parameter to continue training from the weights learned in the run above. Note that we do not drop the model table or model summary table, because warm start initializes the weights from the existing model table:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT madlib.madlib_keras_fit('iris_train_packed', -- source table\n",
" 'iris_model', -- model output table\n",
" 'model_arch_library', -- model arch table\n",
" 1, -- model arch id\n",
" $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$, -- compile_params\n",
" $$ batch_size=5, epochs=3 $$, -- fit_params\n",
" 10, -- num_iterations\n",
" FALSE, -- use GPUs\n",
" 'iris_test_packed', -- validation dataset\n",
" 2, -- metrics compute frequency\n",
" TRUE, -- warm start\n",
" 'Sophie L.', -- name \n",
" 'Simple MLP for iris dataset' -- description\n",
" );"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the summary table and plots below, note that the loss and accuracy values pick up from where the previous run left off:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>model</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>num_iterations</th>\n",
" <th>validation_table</th>\n",
" <th>object_table</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" <th>metrics_iters</th>\n",
" <th>class_text_class_values</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_model</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>1</td>\n",
" <td> loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] </td>\n",
" <td> batch_size=5, epochs=3 </td>\n",
" <td>10</td>\n",
" <td>iris_test_packed</td>\n",
" <td>None</td>\n",
" <td>2</td>\n",
" <td>Sophie L.</td>\n",
" <td>Simple MLP for iris dataset</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>2021-03-06 00:27:51.102558</td>\n",
" <td>2021-03-06 00:27:52.451185</td>\n",
" <td>[0.781347990036011, 0.923561096191406, 1.06405401229858, 1.20302820205688, 1.34854102134705]</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[3]</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.975000023842</td>\n",
" <td>0.194891035557</td>\n",
" <td>[0.941666662693024, 0.966666638851166, 0.958333313465118, 0.949999988079071, 0.975000023841858]</td>\n",
" <td>[0.262409120798111, 0.24169448018074, 0.222953796386719, 0.207046672701836, 0.194891035556793]</td>\n",
" <td>1.0</td>\n",
" <td>0.188044458628</td>\n",
" <td>[0.966666638851166, 1.0, 1.0, 1.0, 1.0]</td>\n",
" <td>[0.293483078479767, 0.254781544208527, 0.232207864522934, 0.212682083249092, 0.188044458627701]</td>\n",
" <td>[2, 4, 6, 8, 10]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_model', [u'class_text'], [u'attributes'], u'model_arch_library', 1, u\" loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] \", u' batch_size=5, epochs=3 ', 10, u'iris_test_packed', None, 2, u'Sophie L.', u'Simple MLP for iris dataset', u'madlib_keras', 0.7900390625, datetime.datetime(2021, 3, 6, 0, 27, 51, 102558), datetime.datetime(2021, 3, 6, 0, 27, 52, 451185), [0.781347990036011, 0.923561096191406, 1.06405401229858, 1.20302820205688, 1.34854102134705], u'1.18.0-dev', [3], [u'character varying'], 1.0, [u'accuracy'], u'categorical_crossentropy', 0.975000023841858, 0.194891035556793, [0.941666662693024, 0.966666638851166, 0.958333313465118, 0.949999988079071, 0.975000023841858], [0.262409120798111, 0.24169448018074, 0.222953796386719, 0.207046672701836, 0.194891035556793], 1.0, 0.188044458627701, [0.966666638851166, 1.0, 1.0, 1.0, 1.0], [0.293483078479767, 0.254781544208527, 0.232207864522934, 0.212682083249092, 0.188044458627701], [2, 4, 6, 8, 10], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'])]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_model_summary;"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# get accuracy and iteration number\n",
"iters_proxy = %sql SELECT metrics_iters FROM iris_model_summary;\n",
"train_accuracy_proxy = %sql SELECT training_metrics FROM iris_model_summary;\n",
"test_accuracy_proxy = %sql SELECT validation_metrics FROM iris_model_summary;\n",
"\n",
"# get number of points\n",
"num_points_proxy = %sql SELECT array_length(metrics_iters,1) FROM iris_model_summary;\n",
"num_points = num_points_proxy[0]\n",
"\n",
"# reshape to np arrays\n",
"iters = np.array(iters_proxy).reshape(num_points)\n",
"train_accuracy = np.array(train_accuracy_proxy).reshape(num_points)\n",
"test_accuracy = np.array(test_accuracy_proxy).reshape(num_points)\n",
"\n",
"# plot\n",
"plt.title('Iris validation accuracy by iteration - warm start')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Accuracy')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_accuracy, 'g.-', label='Train')\n",
"plt.plot(iters, test_accuracy, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# get loss\n",
"train_loss_proxy = %sql SELECT training_loss FROM iris_model_summary;\n",
"test_loss_proxy = %sql SELECT validation_loss FROM iris_model_summary;\n",
"\n",
"# reshape to np arrays\n",
"train_loss = np.array(train_loss_proxy).reshape(num_points)\n",
"test_loss = np.array(test_loss_proxy).reshape(num_points)\n",
"\n",
"# plot\n",
"plt.title('Iris validation loss by iteration - warm start')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Loss')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_loss, 'g.-', label='Train')\n",
"plt.plot(iters, test_loss, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"transfer_learn\"></a>\n",
"# Transfer learning\n",
"\n",
"<a id=\"load2\"></a>\n",
"# 1. Define and load model architecture with some layers frozen\n",
"Here we want to start from the weights of a pre-trained model rather than training from scratch. We also want to freeze the earlier feature layer(s) of the architecture to save on training time. The example below is somewhat contrived, but it illustrates the steps involved.\n",
"\n",
"First, define a model architecture with the first hidden layer frozen (`trainable=False`):"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_3 (Dense) (None, 10) 50 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 10) 110 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 3) 33 \n",
"=================================================================\n",
"Total params: 193\n",
"Trainable params: 143\n",
"Non-trainable params: 50\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model_transfer = Sequential()\n",
"model_transfer.add(Dense(10, activation='relu', input_shape=(4,), trainable=False))\n",
"model_transfer.add(Dense(10, activation='relu'))\n",
"model_transfer.add(Dense(3, activation='softmax'))\n",
"\n",
"model_transfer.summary()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.2.4-tf\", \"config\": {\"layers\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": false, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"dtype\": \"float32\", \"seed\": null}}, \"name\": \"dense_5\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {\"dtype\": \"float32\"}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"name\": \"sequential_1\"}, \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_transfer.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the transfer model into the model architecture table:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>model_arch</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_1', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>Sophie</td>\n",
" <td>A simple model</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': False, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_4', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}</td>\n",
" <td>Maria</td>\n",
" <td>A transfer model</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u' ... (1340 characters truncated) ... s_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, u'Sophie', u'A simple model'),\n",
" (2, {u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u' ... (1341 characters truncated) ... s_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}, u'Maria', u'A transfer model')]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT madlib.load_keras_model('model_arch_library', -- Output table\n",
"$$\n",
"{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": false, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"batch_input_shape\": [null, 4], \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 3, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}\n",
"$$\n",
"::json, -- JSON blob\n",
" NULL, -- Weights\n",
" 'Maria', -- Name\n",
" 'A transfer model' -- Descr\n",
");\n",
"\n",
"SELECT model_id, model_arch, name, description FROM model_arch_library ORDER BY model_id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"train2\"></a>\n",
"# 2. Train transfer model\n",
"\n",
"Copy the weights from the previous MADlib run into the transfer model's row (`model_id` 2) of the model architecture table. (Normally these weights would come from a source that trained the same model architecture on a related dataset.)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"UPDATE model_arch_library \n",
"SET model_weights = iris_model.model_weights \n",
"FROM iris_model \n",
"WHERE model_arch_library.model_id = 2;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
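"Now train using the transfer model architecture. Because the model architecture table now holds pre-trained weights for `model_id` 2, `madlib_keras_fit` uses them as the initial weights, so we drop the previous model tables and train without warm start:"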
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>madlib_keras_fit</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS iris_model, iris_model_summary;\n",
"\n",
"SELECT madlib.madlib_keras_fit('iris_train_packed', -- source table\n",
" 'iris_model', -- model output table\n",
" 'model_arch_library', -- model arch table\n",
" 2, -- model arch id\n",
" $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$, -- compile_params\n",
" $$ batch_size=5, epochs=3 $$, -- fit_params\n",
" 10, -- num_iterations\n",
" FALSE, -- use GPUs\n",
" 'iris_test_packed', -- validation dataset\n",
" 2 -- metrics compute frequency\n",
" );"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>model</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>num_iterations</th>\n",
" <th>validation_table</th>\n",
" <th>object_table</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_type</th>\n",
" <th>loss_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" <th>metrics_iters</th>\n",
" <th>class_text_class_values</th>\n",
" </tr>\n",
" <tr>\n",
" <td>iris_train_packed</td>\n",
" <td>iris_model</td>\n",
" <td>[u'class_text']</td>\n",
" <td>[u'attributes']</td>\n",
" <td>model_arch_library</td>\n",
" <td>2</td>\n",
" <td> loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] </td>\n",
" <td> batch_size=5, epochs=3 </td>\n",
" <td>10</td>\n",
" <td>iris_test_packed</td>\n",
" <td>None</td>\n",
" <td>2</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>madlib_keras</td>\n",
" <td>0.7900390625</td>\n",
" <td>2021-03-06 00:27:56.293667</td>\n",
" <td>2021-03-06 00:27:57.661243</td>\n",
" <td>[0.832237005233765, 0.965812921524048, 1.09816098213196, 1.22954201698303, 1.3674840927124]</td>\n",
" <td>1.18.0-dev</td>\n",
" <td>[3]</td>\n",
" <td>[u'character varying']</td>\n",
" <td>1.0</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>categorical_crossentropy</td>\n",
" <td>0.949999988079</td>\n",
" <td>0.153273612261</td>\n",
" <td>[0.949999988079071, 0.958333313465118, 0.949999988079071, 0.949999988079071, 0.949999988079071]</td>\n",
" <td>[0.182110622525215, 0.173247531056404, 0.165094882249832, 0.158673033118248, 0.153273612260818]</td>\n",
" <td>1.0</td>\n",
" <td>0.134765788913</td>\n",
" <td>[1.0, 1.0, 1.0, 1.0, 1.0]</td>\n",
" <td>[0.177851542830467, 0.161254957318306, 0.152191400527954, 0.142795532941818, 0.134765788912773]</td>\n",
" <td>[2, 4, 6, 8, 10]</td>\n",
" <td>[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'iris_train_packed', u'iris_model', [u'class_text'], [u'attributes'], u'model_arch_library', 2, u\" loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] \", u' batch_size=5, epochs=3 ', 10, u'iris_test_packed', None, 2, None, None, u'madlib_keras', 0.7900390625, datetime.datetime(2021, 3, 6, 0, 27, 56, 293667), datetime.datetime(2021, 3, 6, 0, 27, 57, 661243), [0.832237005233765, 0.965812921524048, 1.09816098213196, 1.22954201698303, 1.3674840927124], u'1.18.0-dev', [3], [u'character varying'], 1.0, [u'accuracy'], u'categorical_crossentropy', 0.949999988079071, 0.153273612260818, [0.949999988079071, 0.958333313465118, 0.949999988079071, 0.949999988079071, 0.949999988079071], [0.182110622525215, 0.173247531056404, 0.165094882249832, 0.158673033118248, 0.153273612260818], 1.0, 0.134765788912773, [1.0, 1.0, 1.0, 1.0, 1.0], [0.177851542830467, 0.161254957318306, 0.152191400527954, 0.142795532941818, 0.134765788912773], [2, 4, 6, 8, 10], [u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica'])]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM iris_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the loss picks up from where the previous training run left off, since training starts from the pre-trained weights:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# get accuracy and iteration number\n",
"iters_proxy = %sql SELECT metrics_iters FROM iris_model_summary;\n",
"train_accuracy_proxy = %sql SELECT training_metrics FROM iris_model_summary;\n",
"test_accuracy_proxy = %sql SELECT validation_metrics FROM iris_model_summary;\n",
"\n",
"# get number of points\n",
"num_points_proxy = %sql SELECT array_length(metrics_iters,1) FROM iris_model_summary;\n",
"num_points = num_points_proxy[0]\n",
"\n",
"# reshape to np arrays\n",
"iters = np.array(iters_proxy).reshape(num_points)\n",
"train_accuracy = np.array(train_accuracy_proxy).reshape(num_points)\n",
"test_accuracy = np.array(test_accuracy_proxy).reshape(num_points)\n",
"\n",
"# plot\n",
"plt.title('Iris validation accuracy by iteration - transfer learn')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Accuracy')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_accuracy, 'g.-', label='Train')\n",
"plt.plot(iters, test_accuracy, 'r.-', label='Test')\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# get loss\n",
"train_loss_proxy = %sql SELECT training_loss FROM iris_model_summary;\n",
"test_loss_proxy = %sql SELECT validation_loss FROM iris_model_summary;\n",
"\n",
"# reshape to np arrays\n",
"train_loss = np.array(train_loss_proxy).reshape(num_points)\n",
"test_loss = np.array(test_loss_proxy).reshape(num_points)\n",
"\n",
"# plot\n",
"plt.title('Iris validation loss by iteration - transfer learn')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Loss')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_loss, 'g.-', label='Train')\n",
"plt.plot(iters, test_loss, 'r.-', label='Test')\n",
"plt.legend();"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 1
}