{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"https://s3.amazonaws.com/greenplum.org/wp-content/uploads/2018/11/14180216/logo-gpdb-light.svg\" alt=\"drawing\" width=\"200\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Science Workshop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook serves as an hands-on introduction to the data science pipeline. Using a single dataset throughout, it begins with loading the data into Greenplum Database (GPDB), then proceeds to data exploration, feature engineering, model development, and model evaluation.\n",
"\n",
"We’ll be using the publicly available [Abalone dataset from the University of California, Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/abalone). The dataset contains nine attributes (including our target prediction column).\n",
"\n",
"| Column Name | Data Type | Description|\n",
"| ---|:---:| ---:|\n",
"|Sex | text | M,F,I[infant]|\n",
"| Length | float | Longest shell measurement|\n",
"|Diameter | float | Perpendicular to length|\n",
"| Height | float | With meat in shell |\n",
"| Whole weight | float | Whole abalone |\n",
"| Shucked weight | float | Weight of meat only |\n",
"| Viscera weight | float | Gut weight (after bleeding) |\n",
"| Shell weight | float | Post-drying |\n",
"| Rings | integer | +1.5 gives the age in years|\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table of contents\n",
"\n",
"<a href=\"#setup\">1. Setup</a>\n",
"\n",
"<a href=\"#load_data\">2. Load data</a>\n",
"\n",
"<a href=\"#data_review\">3. Review raw data</a>\n",
"\n",
"<a href=\"#explore\">4. Explore data</a>\n",
"\n",
"<a href=\"#classification\">5. Classification models</a>\n",
"\n",
"* <a href=\"#logistic\">5a. Logistic regression</a>\n",
"\n",
"* <a href=\"#forest\">5b. Random forest</a>\n",
"\n",
"<a href=\"#regression\">6. Regression models</a>\n",
"\n",
"* <a href=\"#logistic\">6a. Linear regression</a>\n",
"\n",
"* <a href=\"#elastic\">6b. Elastic net</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"setup\"></a>\n",
"# 1. Set Up Your Notebook Environment and Connect to the Database"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# this command allows for visualizations to appear in the notebook\n",
"%matplotlib inline\n",
"%load_ext sql\n",
"import math\n",
"import six\n",
"import pandas as pd\n",
"from sqlalchemy import create_engine\n",
"import numpy as np \n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"pd.set_option('display.max_columns', 200)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Connect to Database"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Greenplum Database 5.x on GCP (PM demo machine) - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load_data\"></a>\n",
"# 2. Load Abalone Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An [abalone](https://simple.wikipedia.org/wiki/Abalone) is a salt water univalve mollusc.\n",
"We'll load the data from ```web external table``` then start looking at the data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" <tr>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>0.514</td>\n",
" <td>0.2245</td>\n",
" <td>0.101</td>\n",
" <td>0.15</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>0.2255</td>\n",
" <td>0.0995</td>\n",
" <td>0.0485</td>\n",
" <td>0.07</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>0.677</td>\n",
" <td>0.2565</td>\n",
" <td>0.1415</td>\n",
" <td>0.21</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>0.516</td>\n",
" <td>0.2155</td>\n",
" <td>0.114</td>\n",
" <td>0.155</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>0.205</td>\n",
" <td>0.0895</td>\n",
" <td>0.0395</td>\n",
" <td>0.055</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>0.3515</td>\n",
" <td>0.141</td>\n",
" <td>0.0775</td>\n",
" <td>0.12</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>0.7775</td>\n",
" <td>0.237</td>\n",
" <td>0.1415</td>\n",
" <td>0.33</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>0.768</td>\n",
" <td>0.294</td>\n",
" <td>0.1495</td>\n",
" <td>0.26</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>0.5095</td>\n",
" <td>0.2165</td>\n",
" <td>0.1125</td>\n",
" <td>0.165</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>0.8945</td>\n",
" <td>0.3145</td>\n",
" <td>0.151</td>\n",
" <td>0.32</td>\n",
" <td>19</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'M', 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15, 15),\n",
" (u'M', 0.35, 0.265, 0.09, 0.2255, 0.0995, 0.0485, 0.07, 7),\n",
" (u'F', 0.53, 0.42, 0.135, 0.677, 0.2565, 0.1415, 0.21, 9),\n",
" (u'M', 0.44, 0.365, 0.125, 0.516, 0.2155, 0.114, 0.155, 10),\n",
" (u'I', 0.33, 0.255, 0.08, 0.205, 0.0895, 0.0395, 0.055, 7),\n",
" (u'I', 0.425, 0.3, 0.095, 0.3515, 0.141, 0.0775, 0.12, 8),\n",
" (u'F', 0.53, 0.415, 0.15, 0.7775, 0.237, 0.1415, 0.33, 20),\n",
" (u'F', 0.545, 0.425, 0.125, 0.768, 0.294, 0.1495, 0.26, 16),\n",
" (u'M', 0.475, 0.37, 0.125, 0.5095, 0.2165, 0.1125, 0.165, 9),\n",
" (u'F', 0.55, 0.44, 0.15, 0.8945, 0.3145, 0.151, 0.32, 19)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"---\n",
"--- The below code will create an external table that points at the data available at IMS/UCI.\n",
"---\n",
"drop external table if exists abalone_web;\n",
"create external web table\n",
" abalone_web (\n",
" sex char(1),\n",
" length float,\n",
" diameter float,\n",
" height float,\n",
" whole_weight float,\n",
" shucked_weight float,\n",
" viscera_weight float,\n",
" shell_weight float,\n",
" rings integer\n",
" )\n",
" location ( 'http://archive.ics.uci.edu/ml/machine-learning-databases/abalone/abalone.data' )\n",
" format 'TEXT' ( delimiter ',' null '' )\n",
" log errors\n",
" segment reject limit 100\n",
";\n",
"select * from abalone_web limit 10\n",
";"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create heap table"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"4177 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>0.514</td>\n",
" <td>0.2245</td>\n",
" <td>0.101</td>\n",
" <td>0.15</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>0.2255</td>\n",
" <td>0.0995</td>\n",
" <td>0.0485</td>\n",
" <td>0.07</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>0.677</td>\n",
" <td>0.2565</td>\n",
" <td>0.1415</td>\n",
" <td>0.21</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>0.516</td>\n",
" <td>0.2155</td>\n",
" <td>0.114</td>\n",
" <td>0.155</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>0.205</td>\n",
" <td>0.0895</td>\n",
" <td>0.0395</td>\n",
" <td>0.055</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>0.3515</td>\n",
" <td>0.141</td>\n",
" <td>0.0775</td>\n",
" <td>0.12</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>0.7775</td>\n",
" <td>0.237</td>\n",
" <td>0.1415</td>\n",
" <td>0.33</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>0.768</td>\n",
" <td>0.294</td>\n",
" <td>0.1495</td>\n",
" <td>0.26</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>0.5095</td>\n",
" <td>0.2165</td>\n",
" <td>0.1125</td>\n",
" <td>0.165</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>0.8945</td>\n",
" <td>0.3145</td>\n",
" <td>0.151</td>\n",
" <td>0.32</td>\n",
" <td>19</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0, u'M', 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15, 15),\n",
" (1, u'M', 0.35, 0.265, 0.09, 0.2255, 0.0995, 0.0485, 0.07, 7),\n",
" (2, u'F', 0.53, 0.42, 0.135, 0.677, 0.2565, 0.1415, 0.21, 9),\n",
" (3, u'M', 0.44, 0.365, 0.125, 0.516, 0.2155, 0.114, 0.155, 10),\n",
" (4, u'I', 0.33, 0.255, 0.08, 0.205, 0.0895, 0.0395, 0.055, 7),\n",
" (5, u'I', 0.425, 0.3, 0.095, 0.3515, 0.141, 0.0775, 0.12, 8),\n",
" (6, u'F', 0.53, 0.415, 0.15, 0.7775, 0.237, 0.1415, 0.33, 20),\n",
" (7, u'F', 0.545, 0.425, 0.125, 0.768, 0.294, 0.1495, 0.26, 16),\n",
" (8, u'M', 0.475, 0.37, 0.125, 0.5095, 0.2165, 0.1125, 0.165, 9),\n",
" (9, u'F', 0.55, 0.44, 0.15, 0.8945, 0.3145, 0.151, 0.32, 19)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone;\n",
"CREATE TABLE abalone (\n",
" id integer,\n",
" sex char(1),\n",
" length float,\n",
" diameter float,\n",
" height float,\n",
" whole_weight float,\n",
" shucked_weight float,\n",
" viscera_weight float,\n",
" shell_weight float,\n",
" rings integer\n",
" )\n",
";\n",
" \n",
"DROP SEQUENCE IF EXISTS abalone_id ;\n",
"CREATE TEMPORARY SEQUENCE abalone_id MINVALUE 0 START 0;\n",
"\n",
"insert into \n",
" abalone \n",
"select \n",
" nextval('abalone_id') as id, *\n",
"FROM \n",
" abalone_web\n",
";\n",
"\n",
"\n",
"SELECT * FROM abalone ORDER BY id LIMIT 10;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"data_review\"></a>\n",
"# 3. Review raw data\n",
"\n",
"Load some data into a local variable ```abalone```"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4177 rows affected.\n"
]
}
],
"source": [
"abalone = %sql select * from abalone where random() < 1.0 order by id limit 5000;\n",
"abalone = abalone.DataFrame()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 4177 entries, 0 to 4176\n",
"Data columns (total 10 columns):\n",
"id 4177 non-null int64\n",
"sex 4177 non-null object\n",
"length 4177 non-null float64\n",
"diameter 4177 non-null float64\n",
"height 4177 non-null float64\n",
"whole_weight 4177 non-null float64\n",
"shucked_weight 4177 non-null float64\n",
"viscera_weight 4177 non-null float64\n",
"shell_weight 4177 non-null float64\n",
"rings 4177 non-null int64\n",
"dtypes: float64(7), int64(2), object(1)\n",
"memory usage: 326.4+ KB\n"
]
}
],
"source": [
"abalone.info();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're interested in estimating the age of the abalone in the data. To get age, add 1.5 to the number of rings. A good place to begin is to create a histogram of the target variable."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi41LCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvSM8oowAAEwdJREFUeJzt3X+MZeV93/H3p2BsAi0Lxhqh3W2XNqtErjdNnBEmchQNpnUAR1kiORSLxotLta2EU6esFK/TP3BTWSJtiGtLKdXG0Kwl12uCnbIKpA7CjFz/ATHYhOVHUjZkCbta2Dj8cMZOYk387R/3wRrvzrDLvXfm7p3n/ZJGc85znnPv9+Fc9nPPc8+5k6pCktSfvzfpAiRJk2EASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjp15qQLeC0XXnhhbdmyZdJlnJJvfetbnHPOOZMuY+zW67hg/Y7NcU2fcY/tkUce+UZVveVk/U7rANiyZQsPP/zwpMs4JfPz88zNzU26jLFbr+OC9Ts2xzV9xj22JM+eSj+ngCSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVOn9Z3AWltbdt9zQtuubYtcv0z78Q7d8p7VKEnSKvIMQJI6ZQBIUqcMAEnqlAEgSZ06aQAkuSPJsSSPL2n7r0n+OMljSX43yYYl2z6S5GCSP0ny00var2htB5PsHv9QJEmvx6mcAfw2cMVxbfcBb6uqHwH+H/ARgCRvBa4F/mnb578nOSPJGcBvAlcCbwXe1/pKkibkpAFQVV8GXjyu7Q+qarGtPghsasvbgX1V9bdV9WfAQeCS9nOwqp6pqu8A+1pfSdKEjOMzgH8N/H5b3gg8t2Tb4da2UrskaUJGuhEsyX8EFoHPjKccSLIT2AkwMzPD/Pz8uB56VS0sLExNrSvZtW3xhLaZs5dvP940jn09HLPlOK7pM6mxDR0ASa4Hfga4vKqqNR8BNi/ptqm18Rrt36eq9gB7AGZnZ2ta/gboevh7pcvd8btr2yK3Hjj5y+TQdXOrUNHqWg/HbDmOa/pMamxDTQEluQL4ZeBnq+rbSzbtB65N8sYkFwNbgT8EvgpsTXJxkrMYfFC8f7TSJUmjOOlbuySfBeaAC5McBm5mcNXPG4H7kgA8WFX/rqqeSHIn8CSDqaEbq+rv2uN8EPgicAZwR1U9sQrjkSSdopMGQFW9b5nm21+j/8eAjy3Tfi9w7+uqTpK0arwTWJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1KmTBkCSO5IcS/L4krYLktyX5On2+/zWniSfTHIwyWNJ3r5knx2t/9NJdqzOcCRJp+pUzgB+G7jiuLbdwP1VtRW4v60DXAlsbT87gdtgEBjAzcA7gEuAm18NDUnSZJw0AKrqy8CLxzVvB/a25b3A1UvaP10DDwIbklwE/DRwX1W9WFUvAfdxYqhIktbQsJ8BzFTV0bb8PDDTljcCzy3pd7i1rdQuSZqQM0d9gKqqJDWOYgCS7GQwfcTMzAzz8/PjeuhVtbCwMDW1rmTXtsUT2mbOXr79eNM49vVwzJbjuKbPpMY2bAC8kOSiqjrapniOtfYjwOYl/Ta1tiPA3HHt88s9cFXtAfYAzM7O1tzc3HLdTjvz8/NMS60ruX73PSe07dq2yK0HTv4yOXTd3CpUtLrWwzFbjuOaPpMa27BTQPuBV6/k2QHcvaT9/e1qoEuBV9pU0ReBdyc5v334++7WJkmakJO+tUvyWQbv3i9McpjB1Ty3AHcmuQF4Frimdb8XuAo4CHwb+ABAVb2Y5D8DX239frWqjv9gWZK0hk4aAFX1vh
U2Xb5M3wJuXOFx7gDueF3VSZJWjXcCS1KnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOjVSACT5D0meSPJ4ks8meVOSi5M8lORgks8lOav1fWNbP9i2bxnHACRJwxk6AJJsBP49MFtVbwPOAK4Ffg34eFX9IPAScEPb5Qbgpdb+8dZPkjQho04BnQmcneRM4AeAo8C7gLva9r3A1W15e1unbb88SUZ8fknSkIYOgKo6Avw68OcM/uF/BXgEeLmqFlu3w8DGtrwReK7tu9j6v3nY55ckjSZVNdyOyfnA54F/CbwM/A6Dd/YfbdM8JNkM/H5VvS3J48AVVXW4bftT4B1V9Y3jHncnsBNgZmbmx/ft2zdUfWttYWGBc889d9JljOTAkVdOaJs5G17465Pvu23jeatQ0epaD8dsOY5r+ox7bJdddtkjVTV7sn5njvAc/xz4s6r6C4AkXwDeCWxIcmZ7l78JONL6HwE2A4fblNF5wF8e/6BVtQfYAzA7O1tzc3MjlLh25ufnmZZaV3L97ntOaNu1bZFbD5z8ZXLourlVqGh1rYdjthzHNX0mNbZRPgP4c+DSJD/Q5vIvB54EHgDe2/rsAO5uy/vbOm37l2rY0w9J0shG+QzgIQZTPl8DDrTH2gN8GLgpyUEGc/y3t11uB97c2m8Cdo9QtyRpRKNMAVFVNwM3H9f8DHDJMn3/Bvj5UZ5PkjQ+3gksSZ0yACSpUwaAJHVqpM8ApHHYsszlp6fq0C3vGWMlUl88A5CkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1Cn/JKTGYpQ/6yhpMjwDkKROjRQASTYkuSvJHyd5KslPJLkgyX1Jnm6/z299k+STSQ4meSzJ28czBEnSMEY9A/gE8H+q6oeBfwY8BewG7q+qrcD9bR3gSmBr+9kJ3Dbic0uSRjB0ACQ5D/gp4HaAqvpOVb0MbAf2tm57gavb8nbg0zXwILAhyUVDVy5JGskoZwAXA38B/M8kX0/yqSTnADNVdbT1eR6YacsbgeeW7H+4tUmSJiBVNdyOySzwIPDOqnooySeAbwK/WFUblvR7qarOT/J7wC1V9ZXWfj/w4ap6+LjH3clgioiZmZkf37dv31D1rbWFhQXOPffcSZcxkgNHXjmhbeZseOGvJ1DMKdq28byh910Px2w5jmv6jHtsl1122SNVNXuyfqNcBnoYOFxVD7X1uxjM97+Q5KKqOtqmeI617UeAzUv239Tavk9V7QH2AMzOztbc3NwIJa6d+fl5pqXWlVy/zKWcu7YtcuuB0/dq4UPXzQ2973o4ZstxXNNnUmMbegqoqp4HnkvyQ63pcuBJYD+wo7XtAO5uy/uB97ergS4FXlkyVSRJWmOjvrX7ReAzSc4CngE+wCBU7kxyA/AscE3rey9wFXAQ+HbrK0makJECoKoeBZabZ7p8mb4F3DjK80mSxsc7gSWpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0aOQCSnJHk60l+r61fnOShJAeTfC7JWa39jW39YNu+ZdTnliQNbxxnAB8Cnlqy/mvAx6vqB4GXgBta+w3AS639462fJGlCRgqAJJuA9wCfausB3gXc1brsBa5uy9vbOm375a2/JGkCRj0D+G/ALwPfbetvBl6uqsW2fhjY2JY3As8BtO2vtP6SpAk4c9gdk/wMcKyqHkkyN66CkuwEdgLMzMwwPz
8/rodeVQsLC1NT60p2bVs8oW3m7OXbTxej/DdfD8dsOY5r+kxqbEMHAPBO4GeTXAW8CfgHwCeADUnObO/yNwFHWv8jwGbgcJIzgfOAvzz+QatqD7AHYHZ2tubm5kYoce3Mz89zOtS6Zfc9I+x94sth17ZFbj0wystkdR26bm7ofU+XYzZujmv6TGpsQ08BVdVHqmpTVW0BrgW+VFXXAQ8A723ddgB3t+X9bZ22/UtVVcM+vyRpNKtxH8CHgZuSHGQwx397a78deHNrvwnYvQrPLUk6RWM5t6+qeWC+LT8DXLJMn78Bfn4czydJGp13AktSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6tRY/iSkxmfL7nsmXYKkTngGIEmdMgAkqVMGgCR1ygCQpE4ZAJLUqaEDIMnmJA8keTLJE0k+1NovSHJfkqfb7/Nbe5J8MsnBJI8lefu4BiFJev1GOQNYBHZV1VuBS4Ebk7wV2A3cX1VbgfvbOsCVwNb2sxO4bYTnliSNaOgAqKqjVfW1tvxXwFPARmA7sLd12wtc3Za3A5+ugQeBDUkuGrpySdJIxnIjWJItwI8BDwEzVXW0bXoemGnLG4Hnlux2uLUdRRrSKDfO7dq2yNz4SpGmzsgBkORc4PPAL1XVN5N8b1tVVZJ6nY+3k8EUETMzM8zPz49a4ppYWFgYS627ti2OXswYzZx9+tU0LjNnMzWvr9djXK/F0816HRdMbmwjBUCSNzD4x/8zVfWF1vxCkouq6mib4jnW2o8Am5fsvqm1fZ+q2gPsAZidna25ublRSlwz8/PzjKPW60+zr4LYtW2RWw+sz28M2bVtkWum5PX1eozrtXi6Wa/jgsmNbZSrgALcDjxVVb+xZNN+YEdb3gHcvaT9/e1qoEuBV5ZMFUmS1tgob+3eCfwCcCDJo63tV4BbgDuT3AA8C1zTtt0LXAUcBL4NfGCE55YkjWjoAKiqrwBZYfPly/Qv4MZhn2+a+I2ekqbB+pzcldbAKEF/6Jb3jLESaTh+FYQkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlJeBqmves6GeeQYgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQ3gklT5mQ3r+3atviaf1vav0WgV3kGIEmd8gxAmgC/gkKnA88AJKlTBoAkdcoAkKROGQCS1CkDQJI6teZXASW5AvgEcAbwqaq6Za1rkHo2yhVI3kOwvqxpACQ5A/hN4F8Ah4GvJtlfVU+uZR2n4vX+T3Kym2+k9WBSl68aPKtjraeALgEOVtUzVfUdYB+wfY1rkCSx9lNAG4HnlqwfBt6xWk/mzTbS+rBl9z1TeZZ9up+5pKrW7smS9wJXVNW/aeu/ALyjqj64pM9OYGdb/SHgT9aswNFcCHxj0kWsgvU6Lli/Y3Nc02fcY/tHVfWWk3Va6zOAI8DmJeubWtv3VNUeYM9aFjUOSR6uqtlJ1zFu63VcsH7H5rimz6TGttafAXwV2Jrk4iRnAdcC+9e4BkkSa3wGUFWLST4IfJHBZaB3VNUTa1mDJGlgze8DqKp7gXvX+nnXwNRNW52i9TouWL9jc1zTZyJjW9MPgSVJpw+/CkKSOmUAjCjJoSQHkjya5OFJ1zOKJHckOZbk8SVtFyS5L8nT7ff5k6xxGCuM66NJjrTj9miSqyZZ4zCSbE7yQJInkzyR5EOtfT0cs5XGNtXHLcmbkvxhkj9q4/pPrf3iJA8lOZjkc+0imdWvxymg0SQ5BMxW1dRfn5zkp4AF4NNV9bbW9l+AF6vqliS7gfOr6sOTrPP1WmFcHwUWqurXJ1nbKJJcBFxUVV9L8veBR4CrgeuZ/mO20tiuYYqPW5IA51TVQpI3AF8BPgTcBHyhqvYl+R/AH1XVbatdj2cA+p6q+jLw4nHN24G9bX
kvg/8Jp8oK45p6VXW0qr7Wlv8KeIrB3fbr4ZitNLapVgMLbfUN7aeAdwF3tfY1O2YGwOgK+IMkj7S7mNebmao62pafB2YmWcyYfTDJY22KaOqmSZZKsgX4MeAh1tkxO25sMOXHLckZSR4FjgH3AX8KvFxVi63LYdYo7AyA0f1kVb0duBK4sU03rEs1mC9cL3OGtwH/BPhR4Chw62TLGV6Sc4HPA79UVd9cum3aj9kyY5v641ZVf1dVP8rgmxAuAX54UrUYACOqqiPt9zHgdxkc0PXkhTYf++q87LEJ1zMWVfVC+x/xu8BvMaXHrc0jfx74TFV9oTWvi2O23NjWy3EDqKqXgQeAnwA2JHn1vqwTviJntRgAI0hyTvuAiiTnAO8GHn/tvabOfmBHW94B3D3BWsbm1X8gm59jCo9b+0DxduCpqvqNJZum/pitNLZpP25J3pJkQ1s+m8HfRnmKQRC8t3Vbs2PmVUAjSPKPGbzrh8Fd1f+rqj42wZJGkuSzwByDbyZ8AbgZ+N/AncA/BJ4FrqmqqfpAdYVxzTGYRijgEPBvl8ybT4UkPwn8X+AA8N3W/CsM5sqn/ZitNLb3McXHLcmPMPiQ9wwGb8DvrKpfbf+W7AMuAL4O/Kuq+ttVr8cAkKQ+OQUkSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6tT/B/8oHIoMCzVuAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"(abalone.rings + 1.5 ).hist(bins=20);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True 2770\n",
"False 1407\n",
"Name: rings, dtype: int64"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"((abalone.rings + 1.5) >= 10).value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Look at the cumulative distribution"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_cumsum_df = pd.DataFrame({'Cumulative Sum':((abalone.rings + 1.5).value_counts().sort_index().cumsum())/abalone.shape[0]})\n",
"ax = n_cumsum_df.plot();\n",
"ax.set_xlabel('Age');\n",
"ax.set_ylabel('Normalized Cumulative Sum');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"explore\"></a>\n",
"# 4. Process and explore data\n",
"\n",
"Create a column for maturity"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"4177 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" <th>age</th>\n",
" <th>mature</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>m</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>0.2255</td>\n",
" <td>0.0995</td>\n",
" <td>0.0485</td>\n",
" <td>0.07</td>\n",
" <td>7</td>\n",
" <td>8.5</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>i</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>0.3515</td>\n",
" <td>0.141</td>\n",
" <td>0.0775</td>\n",
" <td>0.12</td>\n",
" <td>8</td>\n",
" <td>9.5</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>f</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>0.8945</td>\n",
" <td>0.3145</td>\n",
" <td>0.151</td>\n",
" <td>0.32</td>\n",
" <td>19</td>\n",
" <td>20.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>f</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>0.6845</td>\n",
" <td>0.2725</td>\n",
" <td>0.171</td>\n",
" <td>0.205</td>\n",
" <td>10</td>\n",
" <td>11.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>f</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>0.451</td>\n",
" <td>0.188</td>\n",
" <td>0.087</td>\n",
" <td>0.13</td>\n",
" <td>10</td>\n",
" <td>11.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>i</td>\n",
" <td>0.38</td>\n",
" <td>0.275</td>\n",
" <td>0.1</td>\n",
" <td>0.2255</td>\n",
" <td>0.08</td>\n",
" <td>0.049</td>\n",
" <td>0.085</td>\n",
" <td>10</td>\n",
" <td>11.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>f</td>\n",
" <td>0.56</td>\n",
" <td>0.44</td>\n",
" <td>0.14</td>\n",
" <td>0.9285</td>\n",
" <td>0.3825</td>\n",
" <td>0.188</td>\n",
" <td>0.3</td>\n",
" <td>11</td>\n",
" <td>12.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>m</td>\n",
" <td>0.575</td>\n",
" <td>0.425</td>\n",
" <td>0.14</td>\n",
" <td>0.8635</td>\n",
" <td>0.393</td>\n",
" <td>0.227</td>\n",
" <td>0.2</td>\n",
" <td>11</td>\n",
" <td>12.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33</td>\n",
" <td>f</td>\n",
" <td>0.68</td>\n",
" <td>0.55</td>\n",
" <td>0.175</td>\n",
" <td>1.798</td>\n",
" <td>0.815</td>\n",
" <td>0.3925</td>\n",
" <td>0.455</td>\n",
" <td>19</td>\n",
" <td>20.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37</td>\n",
" <td>f</td>\n",
" <td>0.45</td>\n",
" <td>0.355</td>\n",
" <td>0.105</td>\n",
" <td>0.5225</td>\n",
" <td>0.237</td>\n",
" <td>0.1165</td>\n",
" <td>0.145</td>\n",
" <td>8</td>\n",
" <td>9.5</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'm', 0.35, 0.265, 0.09, 0.2255, 0.0995, 0.0485, 0.07, 7, Decimal('8.5'), 0),\n",
" (5, u'i', 0.425, 0.3, 0.095, 0.3515, 0.141, 0.0775, 0.12, 8, Decimal('9.5'), 0),\n",
" (9, u'f', 0.55, 0.44, 0.15, 0.8945, 0.3145, 0.151, 0.32, 19, Decimal('20.5'), 1),\n",
" (13, u'f', 0.535, 0.405, 0.145, 0.6845, 0.2725, 0.171, 0.205, 10, Decimal('11.5'), 1),\n",
" (17, u'f', 0.44, 0.34, 0.1, 0.451, 0.188, 0.087, 0.13, 10, Decimal('11.5'), 1),\n",
" (21, u'i', 0.38, 0.275, 0.1, 0.2255, 0.08, 0.049, 0.085, 10, Decimal('11.5'), 1),\n",
" (25, u'f', 0.56, 0.44, 0.14, 0.9285, 0.3825, 0.188, 0.3, 11, Decimal('12.5'), 1),\n",
" (29, u'm', 0.575, 0.425, 0.14, 0.8635, 0.393, 0.227, 0.2, 11, Decimal('12.5'), 1),\n",
" (33, u'f', 0.68, 0.55, 0.175, 1.798, 0.815, 0.3925, 0.455, 19, Decimal('20.5'), 1),\n",
" (37, u'f', 0.45, 0.355, 0.105, 0.5225, 0.237, 0.1165, 0.145, 8, Decimal('9.5'), 0)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_target;\n",
"CREATE TABLE abalone_target\n",
"AS\n",
"SELECT \n",
" id,\n",
" lower(sex) as sex, --- ensure that sex is indicated in lower case (m, f, i)\n",
" \"length\",\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" rings,\n",
" rings + 1.5 as age, --- Define the age\n",
" CASE WHEN --- Identifies whether the abalone is mature or not (1/0)\n",
" (rings + 1.5) >= 10.0\n",
" THEN 1\n",
" ELSE 0\n",
" END as mature\n",
"FROM abalone\n",
";\n",
"SELECT * FROM abalone_target LIMIT 10; --- show a sample"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Encode categorical variables\n",
"\n",
"Next use [MADlib to one-hot encode](http://madlib.apache.org/docs/latest/group__grp__encode__categorical.html) the “sex” column which is a categorical variable. In order to create a predictive model, we need all our columns to be numerical values. Making sure all our model inputs conform to this standard is an important part of the data science modeling pipeline and is considered part of the preprocessing/data cleaning step of the process."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" <th>age</th>\n",
" <th>mature</th>\n",
" <th>sex_f</th>\n",
" <th>sex_i</th>\n",
" <th>sex_m</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>0.516</td>\n",
" <td>0.2155</td>\n",
" <td>0.114</td>\n",
" <td>0.155</td>\n",
" <td>10</td>\n",
" <td>11.5</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>0.406</td>\n",
" <td>0.1675</td>\n",
" <td>0.081</td>\n",
" <td>0.135</td>\n",
" <td>10</td>\n",
" <td>11.5</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>0.381</td>\n",
" <td>0.1705</td>\n",
" <td>0.075</td>\n",
" <td>0.115</td>\n",
" <td>9</td>\n",
" <td>10.5</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>0.59</td>\n",
" <td>0.445</td>\n",
" <td>0.14</td>\n",
" <td>0.931</td>\n",
" <td>0.356</td>\n",
" <td>0.234</td>\n",
" <td>0.28</td>\n",
" <td>12</td>\n",
" <td>13.5</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>0.465</td>\n",
" <td>0.355</td>\n",
" <td>0.105</td>\n",
" <td>0.4795</td>\n",
" <td>0.227</td>\n",
" <td>0.124</td>\n",
" <td>0.125</td>\n",
" <td>8</td>\n",
" <td>9.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>0.205</td>\n",
" <td>0.15</td>\n",
" <td>0.055</td>\n",
" <td>0.042</td>\n",
" <td>0.0255</td>\n",
" <td>0.015</td>\n",
" <td>0.012</td>\n",
" <td>5</td>\n",
" <td>6.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>0.4</td>\n",
" <td>0.32</td>\n",
" <td>0.095</td>\n",
" <td>0.303</td>\n",
" <td>0.1335</td>\n",
" <td>0.06</td>\n",
" <td>0.1</td>\n",
" <td>7</td>\n",
" <td>8.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>0.505</td>\n",
" <td>0.4</td>\n",
" <td>0.125</td>\n",
" <td>0.583</td>\n",
" <td>0.246</td>\n",
" <td>0.13</td>\n",
" <td>0.175</td>\n",
" <td>7</td>\n",
" <td>8.5</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>0.595</td>\n",
" <td>0.495</td>\n",
" <td>0.185</td>\n",
" <td>1.285</td>\n",
" <td>0.416</td>\n",
" <td>0.224</td>\n",
" <td>0.485</td>\n",
" <td>13</td>\n",
" <td>14.5</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>0.6</td>\n",
" <td>0.475</td>\n",
" <td>0.15</td>\n",
" <td>1.0075</td>\n",
" <td>0.4425</td>\n",
" <td>0.221</td>\n",
" <td>0.28</td>\n",
" <td>15</td>\n",
" <td>16.5</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3, 0.44, 0.365, 0.125, 0.516, 0.2155, 0.114, 0.155, 10, Decimal('11.5'), 1, 0, 0, 1),\n",
" (11, 0.43, 0.35, 0.11, 0.406, 0.1675, 0.081, 0.135, 10, Decimal('11.5'), 1, 0, 0, 1),\n",
" (19, 0.45, 0.32, 0.1, 0.381, 0.1705, 0.075, 0.115, 9, Decimal('10.5'), 1, 0, 0, 1),\n",
" (27, 0.59, 0.445, 0.14, 0.931, 0.356, 0.234, 0.28, 12, Decimal('13.5'), 1, 0, 0, 1),\n",
" (35, 0.465, 0.355, 0.105, 0.4795, 0.227, 0.124, 0.125, 8, Decimal('9.5'), 0, 0, 0, 1),\n",
" (43, 0.205, 0.15, 0.055, 0.042, 0.0255, 0.015, 0.012, 5, Decimal('6.5'), 0, 0, 1, 0),\n",
" (51, 0.4, 0.32, 0.095, 0.303, 0.1335, 0.06, 0.1, 7, Decimal('8.5'), 0, 0, 0, 1),\n",
" (59, 0.505, 0.4, 0.125, 0.583, 0.246, 0.13, 0.175, 7, Decimal('8.5'), 0, 1, 0, 0),\n",
" (67, 0.595, 0.495, 0.185, 1.285, 0.416, 0.224, 0.485, 13, Decimal('14.5'), 1, 1, 0, 0),\n",
" (75, 0.6, 0.475, 0.15, 1.0075, 0.4425, 0.221, 0.28, 15, Decimal('16.5'), 1, 1, 0, 0)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_encoded;\n",
"SELECT\n",
"madlib.encode_categorical_variables (\n",
" 'abalone_target', -- input table\n",
" 'abalone_encoded', -- output table\n",
" 'sex', -- categorical_cols\n",
" NULL, --categorical_cols_to_exclude -- Optional\n",
" NULL, --row_id, -- Optional\n",
" NULL, --top, -- Optional\n",
" NULL, --value_to_drop, -- Optional\n",
" NULL, --encode_null, -- Optional\n",
" NULL, --output_type, -- Optional\n",
" NULL, --output_dictionary, -- Optional\n",
" NULL --distributed_by -- Optional\n",
" )\n",
";\n",
"SELECT *\n",
"FROM abalone_encoded\n",
"LIMIT 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Statistics\n",
"\n",
"The next step through the modeling process is to explore our data. We’ll again use some of MADLib’s built in functionality to generate [descriptive statistics](http://madlib.apache.org/docs/latest/group__grp__summary.html) of our data. This will generate important information about the data including count, number of missing values, the mean, median, maximum, minimum, interquartile range, mode, and variance.\n",
"\n",
"Note that you only want to do this after converting categorical data to numeric data because otherwise the statistics will not be compute correctly."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"14 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>group_by</th>\n",
" <th>group_by_value</th>\n",
" <th>target_column</th>\n",
" <th>column_number</th>\n",
" <th>data_type</th>\n",
" <th>row_count</th>\n",
" <th>distinct_values</th>\n",
" <th>missing_values</th>\n",
" <th>blank_values</th>\n",
" <th>fraction_missing</th>\n",
" <th>fraction_blank</th>\n",
" <th>positive_values</th>\n",
" <th>negative_values</th>\n",
" <th>zero_values</th>\n",
" <th>mean</th>\n",
" <th>variance</th>\n",
" <th>confidence_interval</th>\n",
" <th>min</th>\n",
" <th>max</th>\n",
" <th>first_quartile</th>\n",
" <th>median</th>\n",
" <th>third_quartile</th>\n",
" <th>most_frequent_values</th>\n",
" <th>mfv_frequencies</th>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>id</td>\n",
" <td>1</td>\n",
" <td>int4</td>\n",
" <td>4177</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4176</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2088.0</td>\n",
" <td>1454292.16667</td>\n",
" <td>[2051.42791957426, 2124.57208042574]</td>\n",
" <td>0.0</td>\n",
" <td>4176.0</td>\n",
" <td>1044.0</td>\n",
" <td>2088.0</td>\n",
" <td>3132.0</td>\n",
" <td>[u'3453', u'4114', u'2806', u'2823', u'2700', u'2914', u'2970', u'3559', u'3663', u'3809']</td>\n",
" <td>[5L, 5L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>length</td>\n",
" <td>2</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>134</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.523992099593</td>\n",
" <td>0.0144223076483</td>\n",
" <td>[0.520350088942874, 0.527634110243145]</td>\n",
" <td>0.075</td>\n",
" <td>0.815</td>\n",
" <td>0.45</td>\n",
" <td>0.545</td>\n",
" <td>0.615</td>\n",
" <td>[u'0.55', u'0.625', u'0.575', u'0.58', u'0.62', u'0.6', u'0.5', u'0.57', u'0.63', u'0.61']</td>\n",
" <td>[94L, 94L, 93L, 92L, 87L, 87L, 81L, 79L, 78L, 75L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>diameter</td>\n",
" <td>3</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>111</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.407881254489</td>\n",
" <td>0.00984855103022</td>\n",
" <td>[0.404871645997762, 0.410890862979976]</td>\n",
" <td>0.055</td>\n",
" <td>0.65</td>\n",
" <td>0.35</td>\n",
" <td>0.425</td>\n",
" <td>0.48</td>\n",
" <td>[u'0.45', u'0.475', u'0.4', u'0.5', u'0.47', u'0.48', u'0.455', u'0.46', u'0.44', u'0.485']</td>\n",
" <td>[139L, 120L, 111L, 110L, 100L, 91L, 90L, 89L, 87L, 83L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>height</td>\n",
" <td>4</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4175</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0.13951639933</td>\n",
" <td>0.00174950266443</td>\n",
" <td>[0.138247926591561, 0.140784872067761]</td>\n",
" <td>0.0</td>\n",
" <td>1.13</td>\n",
" <td>0.115</td>\n",
" <td>0.14</td>\n",
" <td>0.165</td>\n",
" <td>[u'0.15', u'0.14', u'0.155', u'0.175', u'0.16', u'0.125', u'0.165', u'0.135', u'0.145', u'0.12']</td>\n",
" <td>[267L, 220L, 217L, 211L, 205L, 202L, 193L, 189L, 182L, 169L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>whole_weight</td>\n",
" <td>5</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>2429</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.828742159445</td>\n",
" <td>0.240481389202</td>\n",
" <td>[0.813870324055101, 0.843613994834057]</td>\n",
" <td>0.002</td>\n",
" <td>2.8255</td>\n",
" <td>0.4415</td>\n",
" <td>0.7995</td>\n",
" <td>1.153</td>\n",
" <td>[u'1.1345', u'0.2225', u'0.196', u'0.4425', u'0.4775', u'1.0835', u'0.97', u'0.874', u'1.1155', u'0.872']</td>\n",
" <td>[10L, 8L, 8L, 7L, 7L, 7L, 7L, 7L, 7L, 7L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>shucked_weight</td>\n",
" <td>6</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>1515</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.359367488628</td>\n",
" <td>0.0492675507435</td>\n",
" <td>[0.352636105342962, 0.366098871913441]</td>\n",
" <td>0.001</td>\n",
" <td>1.488</td>\n",
" <td>0.186</td>\n",
" <td>0.336</td>\n",
" <td>0.502</td>\n",
" <td>[u'0.175', u'0.2505', u'0.21', u'0.419', u'0.2', u'0.097', u'0.165', u'0.096', u'0.0745', u'0.2025']</td>\n",
" <td>[11L, 10L, 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>viscera_weight</td>\n",
" <td>7</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>880</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.180593607853</td>\n",
" <td>0.01201528386</td>\n",
" <td>[0.17726937948362, 0.183917836221431]</td>\n",
" <td>0.0005</td>\n",
" <td>0.76</td>\n",
" <td>0.0935</td>\n",
" <td>0.171</td>\n",
" <td>0.253</td>\n",
" <td>[u'0.1715', u'0.196', u'0.0575', u'0.2195', u'0.061', u'0.037', u'0.096', u'0.1405', u'0.099', u'0.0265']</td>\n",
" <td>[15L, 14L, 13L, 13L, 13L, 13L, 12L, 12L, 12L, 12L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>shell_weight</td>\n",
" <td>8</td>\n",
" <td>float8</td>\n",
" <td>4177</td>\n",
" <td>926</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.238830859469</td>\n",
" <td>0.0193773832022</td>\n",
" <td>[0.234609314715373, 0.243052404221663]</td>\n",
" <td>0.0015</td>\n",
" <td>1.005</td>\n",
" <td>0.13</td>\n",
" <td>0.234</td>\n",
" <td>0.329</td>\n",
" <td>[u'0.275', u'0.25', u'0.185', u'0.265', u'0.315', u'0.3', u'0.17', u'0.285', u'0.175', u'0.22']</td>\n",
" <td>[43L, 42L, 40L, 40L, 40L, 37L, 37L, 37L, 36L, 36L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>rings</td>\n",
" <td>9</td>\n",
" <td>int4</td>\n",
" <td>4177</td>\n",
" <td>28</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>9.93368446253</td>\n",
" <td>10.3952659473</td>\n",
" <td>[9.83590635305212, 10.0314625720137]</td>\n",
" <td>1.0</td>\n",
" <td>29.0</td>\n",
" <td>8.0</td>\n",
" <td>9.0</td>\n",
" <td>11.0</td>\n",
" <td>[u'9', u'10', u'8', u'11', u'7', u'12', u'6', u'13', u'14', u'5']</td>\n",
" <td>[689L, 634L, 568L, 487L, 391L, 267L, 259L, 203L, 126L, 115L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>age</td>\n",
" <td>10</td>\n",
" <td>numeric</td>\n",
" <td>4177</td>\n",
" <td>28</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>4177</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>11.4336844625</td>\n",
" <td>10.3952659473</td>\n",
" <td>[11.3359063530521, 11.5314625720137]</td>\n",
" <td>2.5</td>\n",
" <td>30.5</td>\n",
" <td>9.5</td>\n",
" <td>10.5</td>\n",
" <td>12.5</td>\n",
" <td>[u'10.5', u'11.5', u'9.5', u'12.5', u'8.5', u'13.5', u'7.5', u'14.5', u'15.5', u'6.5']</td>\n",
" <td>[689L, 634L, 568L, 487L, 391L, 267L, 259L, 203L, 126L, 115L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>mature</td>\n",
" <td>11</td>\n",
" <td>int4</td>\n",
" <td>4177</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>2770</td>\n",
" <td>0</td>\n",
" <td>1407</td>\n",
" <td>0.663155374671</td>\n",
" <td>0.223433815173</td>\n",
" <td>[0.648820355293342, 0.67749039404829]</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>[u'1', u'0']</td>\n",
" <td>[2770L, 1407L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>sex_f</td>\n",
" <td>12</td>\n",
" <td>int4</td>\n",
" <td>4177</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>1307</td>\n",
" <td>0</td>\n",
" <td>2870</td>\n",
" <td>0.312903998085</td>\n",
" <td>0.215046569565</td>\n",
" <td>[0.298840605732167, 0.326967390437333]</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>[u'0', u'1']</td>\n",
" <td>[2870L, 1307L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>sex_i</td>\n",
" <td>13</td>\n",
" <td>int4</td>\n",
" <td>4177</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>1342</td>\n",
" <td>0</td>\n",
" <td>2835</td>\n",
" <td>0.32128321762</td>\n",
" <td>0.218112529203</td>\n",
" <td>[0.30711992784858, 0.335446507392023]</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>[u'0', u'1']</td>\n",
" <td>[2835L, 1342L]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>sex_m</td>\n",
" <td>14</td>\n",
" <td>int4</td>\n",
" <td>4177</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>0.0</td>\n",
" <td>None</td>\n",
" <td>1528</td>\n",
" <td>0</td>\n",
" <td>2649</td>\n",
" <td>0.365812784295</td>\n",
" <td>0.23204934521</td>\n",
" <td>[0.351204002328337, 0.380421566261561]</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>[u'0', u'1']</td>\n",
" <td>[2649L, 1528L]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(None, None, u'id', 1, u'int4', 4177L, 4177L, 0L, None, 0.0, None, 4176L, 0L, 1L, 2088.0, 1454292.16666667, [2051.42791957426, 2124.57208042574], 0.0, 4176.0, 1044.0, 2088.0, 3132.0, [u'3453', u'4114', u'2806', u'2823', u'2700', u'2914', u'2970', u'3559', u'3663', u'3809'], [5L, 5L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L]),\n",
" (None, None, u'length', 2, u'float8', 4177L, 134L, 0L, None, 0.0, None, 4177L, 0L, 0L, 0.523992099593009, 0.0144223076482969, [0.520350088942874, 0.527634110243145], 0.075, 0.815, 0.45, 0.545, 0.615, [u'0.55', u'0.625', u'0.575', u'0.58', u'0.62', u'0.6', u'0.5', u'0.57', u'0.63', u'0.61'], [94L, 94L, 93L, 92L, 87L, 87L, 81L, 79L, 78L, 75L]),\n",
" (None, None, u'diameter', 3, u'float8', 4177L, 111L, 0L, None, 0.0, None, 4177L, 0L, 0L, 0.407881254488869, 0.00984855103022442, [0.404871645997762, 0.410890862979976], 0.055, 0.65, 0.35, 0.425, 0.48, [u'0.45', u'0.475', u'0.4', u'0.5', u'0.47', u'0.48', u'0.455', u'0.46', u'0.44', u'0.485'], [139L, 120L, 111L, 110L, 100L, 91L, 90L, 89L, 87L, 83L]),\n",
" (None, None, u'height', 4, u'float8', 4177L, 51L, 0L, None, 0.0, None, 4175L, 0L, 2L, 0.139516399329661, 0.00174950266442686, [0.138247926591561, 0.140784872067761], 0.0, 1.13, 0.115, 0.14, 0.165, [u'0.15', u'0.14', u'0.155', u'0.175', u'0.16', u'0.125', u'0.165', u'0.135', u'0.145', u'0.12'], [267L, 220L, 217L, 211L, 205L, 202L, 193L, 189L, 182L, 169L]),\n",
" (None, None, u'whole_weight', 5, u'float8', 4177L, 2429L, 0L, None, 0.0, None, 4177L, 0L, 0L, 0.828742159444579, 0.240481389201558, [0.813870324055101, 0.843613994834057], 0.002, 2.8255, 0.4415, 0.7995, 1.153, [u'1.1345', u'0.2225', u'0.196', u'0.4425', u'0.4775', u'1.0835', u'0.97', u'0.874', u'1.1155', u'0.872'], [10L, 8L, 8L, 7L, 7L, 7L, 7L, 7L, 7L, 7L]),\n",
" (None, None, u'shucked_weight', 6, u'float8', 4177L, 1515L, 0L, None, 0.0, None, 4177L, 0L, 0L, 0.359367488628202, 0.0492675507435241, [0.352636105342962, 0.366098871913441], 0.001, 1.488, 0.186, 0.336, 0.502, [u'0.175', u'0.2505', u'0.21', u'0.419', u'0.2', u'0.097', u'0.165', u'0.096', u'0.0745', u'0.2025'], [11L, 10L, 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L]),\n",
" (None, None, u'viscera_weight', 7, u'float8', 4177L, 880L, 0L, None, 0.0, None, 4177L, 0L, 0L, 0.180593607852526, 0.0120152838599929, [0.17726937948362, 0.183917836221431], 0.0005, 0.76, 0.0935, 0.171, 0.253, [u'0.1715', u'0.196', u'0.0575', u'0.2195', u'0.061', u'0.037', u'0.096', u'0.1405', u'0.099', u'0.0265'], [15L, 14L, 13L, 13L, 13L, 13L, 12L, 12L, 12L, 12L]),\n",
" (None, None, u'shell_weight', 8, u'float8', 4177L, 926L, 0L, None, 0.0, None, 4177L, 0L, 0L, 0.238830859468518, 0.0193773832021588, [0.234609314715373, 0.243052404221663], 0.0015, 1.005, 0.13, 0.234, 0.329, [u'0.275', u'0.25', u'0.185', u'0.265', u'0.315', u'0.3', u'0.17', u'0.285', u'0.175', u'0.22'], [43L, 42L, 40L, 40L, 40L, 37L, 37L, 37L, 36L, 36L]),\n",
" (None, None, u'rings', 9, u'int4', 4177L, 28L, 0L, None, 0.0, None, 4177L, 0L, 0L, 9.93368446253292, 10.3952659473471, [9.83590635305212, 10.0314625720137], 1.0, 29.0, 8.0, 9.0, 11.0, [u'9', u'10', u'8', u'11', u'7', u'12', u'6', u'13', u'14', u'5'], [689L, 634L, 568L, 487L, 391L, 267L, 259L, 203L, 126L, 115L]),\n",
" (None, None, u'age', 10, u'numeric', 4177L, 28L, 0L, None, 0.0, None, 4177L, 0L, 0L, 11.4336844625329, 10.3952659473471, [11.3359063530521, 11.5314625720137], 2.5, 30.5, 9.5, 10.5, 12.5, [u'10.5', u'11.5', u'9.5', u'12.5', u'8.5', u'13.5', u'7.5', u'14.5', u'15.5', u'6.5'], [689L, 634L, 568L, 487L, 391L, 267L, 259L, 203L, 126L, 115L]),\n",
" (None, None, u'mature', 11, u'int4', 4177L, 2L, 0L, None, 0.0, None, 2770L, 0L, 1407L, 0.663155374670816, 0.223433815172854, [0.648820355293342, 0.67749039404829], 0.0, 1.0, 0.0, 1.0, 1.0, [u'1', u'0'], [2770L, 1407L]),\n",
" (None, None, u'sex_f', 12, u'int4', 4177L, 2L, 0L, None, 0.0, None, 1307L, 0L, 2870L, 0.31290399808475, 0.21504656956495, [0.298840605732167, 0.326967390437333], 0.0, 1.0, 0.0, 0.0, 1.0, [u'0', u'1'], [2870L, 1307L]),\n",
" (None, None, u'sex_i', 13, u'int4', 4177L, 2L, 0L, None, 0.0, None, 1342L, 0L, 2835L, 0.321283217620302, 0.218112529203438, [0.30711992784858, 0.335446507392023], 0.0, 1.0, 0.0, 0.0, 1.0, [u'0', u'1'], [2835L, 1342L]),\n",
" (None, None, u'sex_m', 14, u'int4', 4177L, 2L, 0L, None, 0.0, None, 1528L, 0L, 2649L, 0.365812784294949, 0.232049345210086, [0.351204002328337, 0.380421566261561], 0.0, 1.0, 0.0, 0.0, 1.0, [u'0', u'1'], [2649L, 1528L])]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_summary ;\n",
"SELECT madlib.summary (\n",
" 'abalone_encoded', -- source_table\n",
" 'abalone_summary', -- output_table\n",
" NULL, -- target_cols\n",
" NULL, -- grouping_cols\n",
" TRUE, -- get_distinct\n",
" TRUE, -- get_quartiles\n",
" NULL, -- quantile_array\n",
" 10, -- how_many_mfv\n",
" FALSE -- get_estimate\n",
")\n",
";\n",
"\n",
"SELECT * FROM abalone_summary LIMIT 15\n",
";"
]
},
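{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the summary columns concrete, here is a minimal sketch in plain Python of the same kind of statistics. It is illustrative only: it uses the standard `statistics` module rather than MADlib, the helper name `describe` is hypothetical, and the input is a handful of `length` values copied from the `abalone_encoded` preview rather than the full table. Quartile conventions differ between tools, so the quartiles may not match MADlib exactly.\n",
"\n",
"```python\n",
"import statistics\n",
"\n",
"# A few length values copied from the abalone_encoded preview (not the full table).\n",
"lengths = [0.43, 0.45, 0.59, 0.465, 0.205, 0.4, 0.505, 0.595, 0.6]\n",
"\n",
"def describe(values):\n",
"    # Return a small dictionary of descriptive statistics for a numeric list.\n",
"    q1, median, q3 = statistics.quantiles(values, n=4)\n",
"    return {\n",
"        'count': len(values),\n",
"        'mean': statistics.mean(values),\n",
"        'variance': statistics.variance(values),  # sample variance (n - 1 denominator)\n",
"        'min': min(values),\n",
"        'max': max(values),\n",
"        'first_quartile': q1,\n",
"        'median': median,\n",
"        'third_quartile': q3,\n",
"    }\n",
"\n",
"summary = describe(lengths)\n",
"for name, value in summary.items():\n",
"    print(name, value)\n",
"```\n",
"\n",
"In the notebook itself the database does this work in parallel across segments, so the sketch is only meant to show what each column of the summary table represents."
]
},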
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Correlation\n",
"Another aspect of the data that we might want to know about is the correlation between different columns. We turn again to MADlib to provide a ready made function: [correlation()](http://madlib.apache.org/docs/latest/group__grp__correlation.html)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>correlation</th>\n",
" </tr>\n",
" <tr>\n",
" <td>Summary for 'Correlation' function<br>Output table = abalone_correlations<br>Grouping columns: sex_f,sex_i,sex_m<br>Producing correlation for columns: length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings<br>Total run time = ('abalone_correlations', 8, 0.21109700202941895)</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u\"Summary for 'Correlation' function\\nOutput table = abalone_correlations\\nGrouping columns: sex_f,sex_i,sex_m\\nProducing correlation for columns: length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings\\nTotal run time = ('abalone_correlations', 8, 0.21109700202941895)\",)]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_correlations;\n",
"DROP TABLE IF EXISTS abalone_correlations_summary;\n",
"\n",
"SELECT\n",
"madlib.correlation(\n",
" 'abalone_encoded', -- source_table,\n",
" 'abalone_correlations', -- output_table,\n",
" 'length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings', -- target_cols,\n",
" TRUE, -- verbose,\n",
" 'sex_f,sex_i,sex_m' -- grouping_columns\n",
");"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8 rows affected.\n"
]
}
],
"source": [
"Index = %sql SELECT variable FROM abalone_correlations WHERE sex_m = '1' ORDER BY column_position;\n",
"Index = Index.DataFrame();"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings\n"
]
}
],
"source": [
"columns = ','.join(','.join('%s' %x for x in y) for y in Index.values)\n",
"print(columns)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8 rows affected.\n"
]
}
],
"source": [
"correlations_male = %sql select variable,{columns} from abalone_correlations where sex_m = '1' ORDER BY column_position;\n",
"correlations_male = correlations_male.DataFrame()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"corm = correlations_male.set_index('variable')\n",
"ax = sns.heatmap(corm)\n",
"ax.set_title('Correlations for Male Abalone')\n",
"plt.show();"
]
},
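{
"cell_type": "markdown",
"metadata": {},
"source": [
"For readers who want to see what a single cell of the heatmap above represents, the following is a minimal sketch of the Pearson correlation coefficient in plain Python. It is not MADlib's implementation; the function name `pearson` is hypothetical, and the inputs are a few `length` and `diameter` values copied from the `abalone_encoded` preview rather than the full table.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def pearson(xs, ys):\n",
"    # Pearson correlation coefficient of two equal-length numeric sequences.\n",
"    n = len(xs)\n",
"    mean_x = sum(xs) / n\n",
"    mean_y = sum(ys) / n\n",
"    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))\n",
"    var_x = sum((x - mean_x) ** 2 for x in xs)\n",
"    var_y = sum((y - mean_y) ** 2 for y in ys)\n",
"    return cov / math.sqrt(var_x * var_y)\n",
"\n",
"# length and diameter values copied from the abalone_encoded preview.\n",
"length = [0.43, 0.45, 0.59, 0.465, 0.205, 0.4, 0.505, 0.595, 0.6]\n",
"diameter = [0.35, 0.32, 0.445, 0.355, 0.15, 0.32, 0.4, 0.495, 0.475]\n",
"\n",
"r = pearson(length, diameter)\n",
"print(round(r, 3))\n",
"```\n",
"\n",
"A value close to 1 means the two measurements rise together, which is what we would expect for two shell dimensions."
]
},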
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sampling for Training and Testing\n",
"Ensuring predictive power in large part is the result of creating a hold-out data set that we don’t train our model with. By creating this subset of the data, we can test any model we develop against “unseen” data to prevent overfitting by our model. This has the benefit of generating a predictive model that will generalize better.\n",
"\n",
"There’s no right answer as to how much data to set aside in the test table; a 70&-30% split, weighted towards the training data, is a good rule of thumb. This process is referred to as the [train-test split](http://madlib.apache.org/docs/latest/group__grp__train__test__split.html).\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>Number of abalone in the Training Set</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2924</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2924L,)]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_classif CASCADE;\n",
"DROP TABLE IF EXISTS abalone_classif_train CASCADE;\n",
"DROP TABLE IF EXISTS abalone_classif_test CASCADE;\n",
"SELECT madlib.train_test_split(\n",
" 'abalone_encoded', -- source_table,\n",
" 'abalone_classif', -- output_table,\n",
" 0.7, -- train_proportion,\n",
" NULL, -- test_proportion,\n",
" NULL, -- grouping_cols,\n",
" 'id,length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,sex_f,sex_i,sex_m,rings,age,mature', -- target_cols,\n",
" FALSE, -- with_replacement,\n",
" TRUE -- separate_output_tables\n",
")\n",
";\n",
"SELECT count(*) as \"Number of abalone in the Training Set\" FROM abalone_classif_train\n",
";"
]
},
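{
"cell_type": "markdown",
"metadata": {},
"source": [
"The idea behind the split can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not how MADlib implements it: the helper name `split_rows`, the fixed seed, and the rounding rule for the cut point are all hypothetical choices, and `range(4177)` merely stands in for the abalone row ids.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def split_rows(rows, train_proportion=0.7, seed=42):\n",
"    # Shuffle a copy of the rows, then cut it into train and test lists.\n",
"    # Every row lands in exactly one list (sampling without replacement).\n",
"    shuffled = list(rows)\n",
"    random.Random(seed).shuffle(shuffled)\n",
"    cut = round(len(shuffled) * train_proportion)\n",
"    return shuffled[:cut], shuffled[cut:]\n",
"\n",
"ids = range(4177)  # stand-in for the abalone row ids\n",
"train, test = split_rows(ids)\n",
"print(len(train), len(test))\n",
"```\n",
"\n",
"Fixing the seed makes the split reproducible, which is useful when comparing several models against the same hold-out set."
]
},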
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"classification\"></a>\n",
"# 5. Classification models\n",
"\n",
"<a id=\"logistic\"></a>\n",
"## 5a. Logistic Regression\n",
"\n",
"We’re now ready to create our first predictive model. We’ll start with a classic [logistic regression](http://madlib.apache.org/docs/latest/group__grp__logreg.html).\n",
"\n",
"Note: we drop one of the 1-hot-encoded variables to remove perfect colinearity"
]
},
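{
"cell_type": "markdown",
"metadata": {},
"source": [
"The collinearity mentioned in the note can be checked directly. The sketch below uses a few made-up encoded rows (hypothetical values, in the order `sex_f`, `sex_i`, `sex_m`) to show that the three indicator columns always add up to the intercept column of ones, so any one of them can be reconstructed from the intercept and the other two.\n",
"\n",
"```python\n",
"# Encoded rows as (sex_f, sex_i, sex_m); each abalone belongs to exactly one category.\n",
"encoded_rows = [(0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 0, 1)]\n",
"intercept = [1] * len(encoded_rows)\n",
"\n",
"# The indicator columns sum to the intercept column on every row ...\n",
"sums = [f + i + m for f, i, m in encoded_rows]\n",
"print(sums == intercept)\n",
"\n",
"# ... so the dropped column can be recovered from the intercept and the other two.\n",
"recovered_sex_i = [1 - f - m for f, i, m in encoded_rows]\n",
"print(recovered_sex_i == [i for f, i, m in encoded_rows])\n",
"```\n",
"\n",
"Because the dropped column carries no extra information, removing it does not change what the model can express; the dropped category simply becomes the baseline that the remaining indicator coefficients are measured against."
]
},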
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>method</th>\n",
" <th>source_table</th>\n",
" <th>out_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>optimizer_params</th>\n",
" <th>num_all_groups</th>\n",
" <th>num_failed_groups</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_missing_rows_skipped</th>\n",
" <th>grouping_col</th>\n",
" </tr>\n",
" <tr>\n",
" <td>logregr</td>\n",
" <td>abalone_classif_train</td>\n",
" <td>abalone_logreg_model</td>\n",
" <td>mature</td>\n",
" <td>ARRAY[<br> 1,<br> length,<br> diameter,<br> height,<br> whole_weight,<br> shucked_weight,<br> viscera_weight,<br> shell_weight,<br> sex_f,<br> sex_m<br> ]</td>\n",
" <td>optimizer=irls, max_iter=20, tolerance=0.0001</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2924</td>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'logregr', u'abalone_classif_train', u'abalone_logreg_model', u'mature', u'ARRAY[\\n 1,\\n length,\\n diameter,\\n height,\\n whole_weight,\\n shucked_weight,\\n viscera_weight,\\n shell_weight,\\n sex_f,\\n sex_m\\n ]', u'optimizer=irls, max_iter=20, tolerance=0.0001', 1, 0, 2924, 0, None)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_logreg_model;\n",
"DROP TABLE IF EXISTS abalone_logreg_model_summary;\n",
"\n",
"SELECT\n",
" madlib.logregr_train( --- Train the Logistic Regression Model\n",
" 'abalone_classif_train', -- source_table,\n",
" 'abalone_logreg_model', -- out_table,\n",
" 'mature', -- dependent_varname,\n",
" 'ARRAY[\n",
" 1,\n",
" length,\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" sex_f,\n",
" sex_m\n",
" ]' -- independent_varname,\n",
" --, -- grouping_cols,\n",
" --, -- max_iter,\n",
" --, -- optimizer,\n",
" --, -- tolerance,\n",
" -- verbose\n",
" )\n",
";\n",
"SELECT * FROM abalone_logreg_model_summary ; --- Get the summary table"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Review model weights"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>log_likelihood</th>\n",
" <th>std_err</th>\n",
" <th>z_stats</th>\n",
" <th>p_values</th>\n",
" <th>odds_ratios</th>\n",
" <th>condition_no</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_missing_rows_skipped</th>\n",
" <th>num_iterations</th>\n",
" <th>variance_covariance</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[-4.67737274255474, -0.392212133228859, 6.15510638669522, 2.08463990575267, 4.87504773295541, -13.7623409061431, -1.36671380800461, 17.1547402958807, 1.01917554661788, 0.942108041681632]</td>\n",
" <td>-1006.98331724</td>\n",
" <td>[0.590840441026775, 2.97184057268197, 3.73897571658262, 2.03371620613073, 1.83990223675481, 2.03448594937021, 3.0108856938889, 2.83197606788329, 0.152140116528269, 0.139164221999888]</td>\n",
" <td>[-7.91647358198146, -0.131976168854476, 1.64620122013548, 1.02503972750398, 2.64962324387079, -6.76452983634679, -0.453924176124849, 6.05751598342522, 6.69892708034377, 6.76975754359042]</td>\n",
" <td>[2.44341625367833e-15, 0.895003141346682, 0.0997223381574991, 0.305344442807183, 0.00805815758922618, 1.33742797426474e-11, 0.649883402076758, 1.38239609928844e-09, 2.0995532130572e-11, 1.28998390096206e-11]</td>\n",
" <td>[0.00930342429755957, 0.675560789867617, 471.116960371319, 8.04169519208295, 130.980405143444, 1.05460854620395e-06, 0.254943377371314, 28197398.9511836, 2.77090933640556, 2.56538365793589]</td>\n",
" <td>121.997664048</td>\n",
" <td>2924</td>\n",
" <td>0</td>\n",
" <td>7</td>\n",
" <td>[[0.349092426752714, -0.675574747128063, -0.397935077415633, -0.180020641088026, 0.147499936799071, -0.0153052244481757, 0.122866412950121, 0.34417106132921, -0.00868455748700556, -0.0138313591702408], [-0.675574747128063, 8.83183638943869, -9.04613626962293, -0.0601147873774726, -0.411410454286909, 0.158792353760485, -0.555224429178143, 0.279153789377445, 0.0398313893638008, 0.0352908716573887], [-0.397935077415633, -9.04613626962293, 13.9799394091945, -0.517286514331738, -0.0639955675289591, -0.0262239231108973, 0.360627238739412, -2.21869974874036, -0.0254340823712027, -0.00678581818134909], [-0.180020641088026, -0.0601147873774725, -0.517286514331738, 4.13600160707876, -0.0853550013704344, 0.0757036887508739, -0.164040489691675, -0.341456748118747, -0.0224655554447562, -0.0070919467634824], [0.147499936799071, -0.411410454286909, -0.0639955675289595, -0.0853550013704344, 3.38524024081537, -3.33519672860865, -3.49335772859094, -3.39500733031394, -0.0262681185753689, -0.0189083677062056], [-0.0153052244481757, 0.158792353760485, -0.0262239231108971, 0.0757036887508739, -3.33519672860865, 4.13913307818482, 2.35314654677864, 2.96795813139297, 0.0299367195410633, 0.00974680695082926], [0.122866412950121, -0.555224429178143, 0.360627238739412, -0.164040489691675, -3.49335772859094, 2.35314654677864, 9.06543266166485, 1.91215690195071, -0.0296918118533269, -0.0318284105816695], [0.34417106132921, 0.279153789377445, -2.21869974874036, -0.341456748118747, -3.39500733031394, 2.96795813139297, 1.91215690195071, 8.0200884490637, 0.027466190787868, 0.0262589069849619], [-0.00868455748700556, 0.0398313893638008, -0.0254340823712027, -0.0224655554447562, -0.0262681185753689, 0.0299367195410632, -0.0296918118533269, 0.027466190787868, 0.0231466150572352, 0.00991437265759586], [-0.0138313591702408, 0.0352908716573887, -0.00678581818134908, -0.0070919467634824, -0.0189083677062056, 0.00974680695082926, -0.0318284105816695, 0.0262589069849619, 0.00991437265759586, 
0.019366680684834]]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([-4.67737274255474, -0.392212133228859, 6.15510638669522, 2.08463990575267, 4.87504773295541, -13.7623409061431, -1.36671380800461, 17.1547402958807, 1.01917554661788, 0.942108041681632], -1006.98331724464, [0.590840441026775, 2.97184057268197, 3.73897571658262, 2.03371620613073, 1.83990223675481, 2.03448594937021, 3.0108856938889, 2.83197606788329, 0.152140116528269, 0.139164221999888], [-7.91647358198146, -0.131976168854476, 1.64620122013548, 1.02503972750398, 2.64962324387079, -6.76452983634679, -0.453924176124849, 6.05751598342522, 6.69892708034377, 6.76975754359042], [2.44341625367833e-15, 0.895003141346682, 0.0997223381574991, 0.305344442807183, 0.00805815758922618, 1.33742797426474e-11, 0.649883402076758, 1.38239609928844e-09, 2.0995532130572e-11, 1.28998390096206e-11], [0.00930342429755957, 0.675560789867617, 471.116960371319, 8.04169519208295, 130.980405143444, 1.05460854620395e-06, 0.254943377371314, 28197398.9511836, 2.77090933640556, 2.56538365793589], 121.997664047952, 2924L, 0L, 7, [[0.349092426752714, -0.675574747128063, -0.397935077415633, -0.180020641088026, 0.147499936799071, -0.0153052244481757, 0.122866412950121, 0.34417106 ... (1703 characters truncated) ... 4908, -0.0070919467634824, -0.0189083677062056, 0.00974680695082926, -0.0318284105816695, 0.0262589069849619, 0.00991437265759586, 0.019366680684834]])]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM abalone_logreg_model ;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show coefficients from model"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
}
],
"source": [
"logreg_coefs = %sql SELECT coef FROM abalone_logreg_model ;\n",
"logreg_coefs = logreg_coefs.DataFrame();"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(('intercept', -4.67737274255474),\n",
" ('length', -0.392212133228859),\n",
" ('diameter', 6.15510638669522),\n",
" ('height', 2.08463990575267),\n",
" ('whole_weight', 4.87504773295541),\n",
" ('shucked_weight', -13.7623409061431),\n",
" ('viscera_weight', -1.36671380800461),\n",
" ('shell_weight', 17.1547402958807),\n",
" ('sex_f', 1.01917554661788),\n",
" ('sex_m', 0.942108041681632))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logreg_coef_names = (\n",
" 'intercept',\n",
" 'length',\n",
" 'diameter',\n",
" 'height',\n",
" 'whole_weight',\n",
" 'shucked_weight',\n",
" 'viscera_weight',\n",
" 'shell_weight',\n",
" 'sex_f',\n",
" 'sex_m'\n",
")\n",
"tuple(zip(logreg_coef_names, logreg_coefs.iloc[0, 0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Score the model against test data and join with the test data\n",
"\n",
"Now that we have a model with coefficients, we can make predictions on records previously unseen by the model. In the current version of MADlib (1.16), the way to predict probability using a logistic regression model is to `CROSS JOIN` the test set records with the single-row model table. A `CROSS JOIN` produces the cartesian product between all records in both tables, meaning it pairs every record from one table with every record in the other table. In Postgres/Greenplum this can be done be explicitly using the `CROSS JOIN` statement, or you can simply list the two tables in the `FROM` clause separated by a comma. "
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1253 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>proba</th>\n",
" <th>mature</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.991068065265</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.826655523716</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.608566963248</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.916620487399</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.309092601567</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.516539096279</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.5456252611</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.318211176601</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.992267215094</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.831075550203</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.99106806526485, 1),\n",
" (0.826655523716257, 1),\n",
" (0.608566963247747, 1),\n",
" (0.91662048739902, 1),\n",
" (0.309092601566622, 1),\n",
" (0.516539096278762, 1),\n",
" (0.545625261099759, 0),\n",
" (0.318211176600757, 0),\n",
" (0.992267215093604, 1),\n",
" (0.831075550203186, 1)]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_logreg_test_proba;\n",
"CREATE TABLE abalone_logreg_test_proba\n",
"AS\n",
"SELECT madlib.logregr_predict_prob( --- Use the logistic regression model to estimate probability of mature\n",
" coef, \n",
" ARRAY[\n",
" 1,\n",
" length,\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" sex_f,\n",
" sex_m\n",
" ] \n",
" ) as proba,\n",
" test.mature\n",
"FROM abalone_classif_test test, abalone_logreg_model model\n",
";\n",
"\n",
"SELECT * --- take a look at a few of the values\n",
"FROM abalone_logreg_test_proba\n",
"LIMIT 10\n",
";"
]
},
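{
"cell_type": "markdown",
"metadata": {},
"source": [
"The comma-separated `FROM` clause above is shorthand for a cross join. As a minimal sketch of the explicit form, the query below pairs the same two tables with `CROSS JOIN` and only counts the resulting rows; the alias `n_pairs` is ours, chosen for illustration. Assuming the model table holds a single row, as described above, the count should match the number of test records."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Explicit CROSS JOIN: every test record is paired with every model row\n",
"SELECT count(*) AS n_pairs\n",
"FROM abalone_classif_test test\n",
"CROSS JOIN abalone_logreg_model model\n",
";"
]
},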
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare the highest predicted probabilities with the actual maturity"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>probability</th>\n",
" <th>mature</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999999986635</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999999582192</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.99999808571</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999993019587</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.99999142131</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999989619946</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999989488185</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999979937474</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999976591803</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.999971356358</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.999999986634653, 1),\n",
" (0.999999582191841, 1),\n",
" (0.999998085709951, 1),\n",
" (0.99999301958695, 1),\n",
" (0.999991421309669, 1),\n",
" (0.999989619945967, 1),\n",
" (0.999989488185467, 1),\n",
" (0.999979937474046, 1),\n",
" (0.999976591802924, 1),\n",
" (0.999971356358266, 1)]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT proba as probability, mature\n",
" FROM abalone_logreg_test_proba\n",
" ORDER BY probability DESC\n",
" LIMIT 10;"
]
},
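{
"cell_type": "markdown",
"metadata": {},
"source": [
"The highest-scoring records above are all mature. A complementary check, sketched here, is to look at the immature records that received the highest probabilities, since these are the model's most confident mistakes. The query assumes only the `proba` and `mature` columns created earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Immature records (mature = 0) ranked by predicted probability of being mature\n",
"SELECT proba AS probability, mature\n",
"FROM abalone_logreg_test_proba\n",
"WHERE mature = 0\n",
"ORDER BY proba DESC\n",
"LIMIT 10;"
]
},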
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculate the area under the ROC curve"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>area_under_roc</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.89729162423337180618758011406069658496850</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('0.89729162423337180618758011406069658496850'),)]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_logreg_test_auc;\n",
"SELECT\n",
"madlib.area_under_roc(\n",
" 'abalone_logreg_test_proba', -- table_in\n",
" 'abalone_logreg_test_auc', -- table_out\n",
" 'proba', -- prediction_col\n",
" 'mature' -- observed_col\n",
") as result\n",
";\n",
"SELECT *\n",
"FROM abalone_logreg_test_auc\n",
";"
]
},
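{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to read this number: the area under the ROC curve equals the probability that a randomly chosen mature record receives a higher score than a randomly chosen immature one, with ties counted as one half. The sketch below checks that interpretation by brute force, cross joining the mature and immature test records. It is an illustration only, not how MADlib computes the value, and the aliases `pos`, `neg`, and `pairwise_auc` are hypothetical. The result should be close to the `area_under_roc` value above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Fraction of (mature, immature) pairs in which the mature record scores higher\n",
"SELECT avg(\n",
"    CASE\n",
"        WHEN pos.proba > neg.proba THEN 1.0\n",
"        WHEN pos.proba = neg.proba THEN 0.5\n",
"        ELSE 0.0\n",
"    END\n",
") AS pairwise_auc\n",
"FROM abalone_logreg_test_proba pos\n",
"CROSS JOIN abalone_logreg_test_proba neg\n",
"WHERE pos.mature = 1\n",
"  AND neg.mature = 0\n",
";"
]
},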
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1253 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>predicted</th>\n",
" <th>mature</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 1),\n",
" (1, 1),\n",
" (0, 1),\n",
" (1, 0),\n",
" (1, 1),\n",
" (1, 1),\n",
" (1, 1),\n",
" (1, 1),\n",
" (1, 1),\n",
" (0, 0)]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_logreg_test_predict ;\n",
"CREATE TABLE abalone_logreg_test_predict\n",
"AS\n",
"SELECT\n",
" (proba >= 0.5)::integer as predicted,\n",
" mature\n",
"FROM abalone_logreg_test_proba\n",
";\n",
"SELECT * \n",
" FROM abalone_logreg_test_predict\n",
" LIMIT 10\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>row_id</th>\n",
" <th>Actual class</th>\n",
" <th>Predicted: Not Mature</th>\n",
" <th>Predicted: Mature</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>306</td>\n",
" <td>123</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>96</td>\n",
" <td>728</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1L, 0, Decimal('306'), Decimal('123')),\n",
" (2L, 1, Decimal('96'), Decimal('728'))]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_logreg_test_conf_matrix;\n",
"SELECT\n",
"madlib.confusion_matrix(\n",
" 'abalone_logreg_test_predict', -- table_in\n",
" 'abalone_logreg_test_conf_matrix', -- table_out\n",
" 'predicted', -- prediction_col\n",
" 'mature' -- observation_col\n",
")\n",
";\n",
"SELECT --- display the confusion matrix (rows: actual class, columns: predicted class)\n",
" row_id,\n",
" class as \"Actual class\",\n",
" confusion_arr[1] as \"Predicted: Not Mature\",\n",
" confusion_arr[2] as \"Predicted: Mature\"\n",
"FROM abalone_logreg_test_conf_matrix\n",
"ORDER BY row_id\n",
";"
]
},
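{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four cells of the confusion matrix are enough to derive the usual summary metrics. As a sketch, the query below recounts them directly from `abalone_logreg_test_predict` and computes accuracy, true positive rate (recall), and positive predictive value (precision). The output names `acc`, `tpr`, and `ppv` are chosen here for illustration. `NULLIF` guards against division by zero if a class is never predicted or never observed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT\n",
"    tp, fp, fn, tn,\n",
"    (tp + tn)::float8 / (tp + fp + fn + tn) AS acc, -- accuracy\n",
"    tp::float8 / NULLIF(tp + fn, 0) AS tpr,         -- recall\n",
"    tp::float8 / NULLIF(tp + fp, 0) AS ppv          -- precision\n",
"FROM (\n",
"    -- count each cell of the confusion matrix\n",
"    SELECT\n",
"        sum(CASE WHEN predicted = 1 AND mature = 1 THEN 1 ELSE 0 END) AS tp,\n",
"        sum(CASE WHEN predicted = 1 AND mature = 0 THEN 1 ELSE 0 END) AS fp,\n",
"        sum(CASE WHEN predicted = 0 AND mature = 1 THEN 1 ELSE 0 END) AS fn,\n",
"        sum(CASE WHEN predicted = 0 AND mature = 0 THEN 1 ELSE 0 END) AS tn\n",
"    FROM abalone_logreg_test_predict\n",
") counts\n",
";"
]
},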
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get ROC values (thresholds, true-positives, false-positives)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1253</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1253L,)]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_logreg_test_binary_metrics ;\n",
"SELECT\n",
"madlib.binary_classifier(\n",
" 'abalone_logreg_test_proba', -- table_in\n",
" 'abalone_logreg_test_binary_metrics', -- table_out\n",
" 'proba', --prediction_col\n",
" 'mature' --observation_col\n",
")\n",
";\n",
"select count(*) from abalone_logreg_test_binary_metrics\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"19 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>threshold</th>\n",
" <th>tp</th>\n",
" <th>fp</th>\n",
" <th>fn</th>\n",
" <th>tn</th>\n",
" <th>tpr</th>\n",
" <th>tnr</th>\n",
" <th>ppv</th>\n",
" <th>npv</th>\n",
" <th>fpr</th>\n",
" <th>fdr</th>\n",
" <th>fnr</th>\n",
" <th>acc</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.480496919575</td>\n",
" <td>735</td>\n",
" <td>130</td>\n",
" <td>89</td>\n",
" <td>299</td>\n",
" <td>0.891990291262</td>\n",
" <td>0.69696969697</td>\n",
" <td>0.849710982659</td>\n",
" <td>0.770618556701</td>\n",
" <td>0.30303030303</td>\n",
" <td>0.150289017341</td>\n",
" <td>0.108009708738</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.87033747779751332149</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.481956373088</td>\n",
" <td>734</td>\n",
" <td>130</td>\n",
" <td>90</td>\n",
" <td>299</td>\n",
" <td>0.890776699029</td>\n",
" <td>0.69696969697</td>\n",
" <td>0.849537037037</td>\n",
" <td>0.768637532134</td>\n",
" <td>0.30303030303</td>\n",
" <td>0.150462962963</td>\n",
" <td>0.109223300971</td>\n",
" <td>0.824421388667</td>\n",
" <td>0.86966824644549763033</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.484870083561</td>\n",
" <td>734</td>\n",
" <td>129</td>\n",
" <td>90</td>\n",
" <td>300</td>\n",
" <td>0.890776699029</td>\n",
" <td>0.699300699301</td>\n",
" <td>0.850521436848</td>\n",
" <td>0.769230769231</td>\n",
" <td>0.300699300699</td>\n",
" <td>0.149478563152</td>\n",
" <td>0.109223300971</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.87018375815056312982</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.485451087833</td>\n",
" <td>734</td>\n",
" <td>128</td>\n",
" <td>90</td>\n",
" <td>301</td>\n",
" <td>0.890776699029</td>\n",
" <td>0.701631701632</td>\n",
" <td>0.85150812065</td>\n",
" <td>0.769820971867</td>\n",
" <td>0.298368298368</td>\n",
" <td>0.14849187935</td>\n",
" <td>0.109223300971</td>\n",
" <td>0.826017557861</td>\n",
" <td>0.87069988137603795967</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.485865382954</td>\n",
" <td>733</td>\n",
" <td>128</td>\n",
" <td>91</td>\n",
" <td>301</td>\n",
" <td>0.889563106796</td>\n",
" <td>0.701631701632</td>\n",
" <td>0.851335656214</td>\n",
" <td>0.767857142857</td>\n",
" <td>0.298368298368</td>\n",
" <td>0.148664343786</td>\n",
" <td>0.110436893204</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.87002967359050445104</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.487008367123</td>\n",
" <td>732</td>\n",
" <td>128</td>\n",
" <td>92</td>\n",
" <td>301</td>\n",
" <td>0.888349514563</td>\n",
" <td>0.701631701632</td>\n",
" <td>0.851162790698</td>\n",
" <td>0.765903307888</td>\n",
" <td>0.298368298368</td>\n",
" <td>0.148837209302</td>\n",
" <td>0.111650485437</td>\n",
" <td>0.824421388667</td>\n",
" <td>0.86935866983372921615</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.490888785527</td>\n",
" <td>732</td>\n",
" <td>127</td>\n",
" <td>92</td>\n",
" <td>302</td>\n",
" <td>0.888349514563</td>\n",
" <td>0.703962703963</td>\n",
" <td>0.852153667055</td>\n",
" <td>0.766497461929</td>\n",
" <td>0.296037296037</td>\n",
" <td>0.147846332945</td>\n",
" <td>0.111650485437</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.86987522281639928699</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.491507015169</td>\n",
" <td>731</td>\n",
" <td>127</td>\n",
" <td>93</td>\n",
" <td>302</td>\n",
" <td>0.88713592233</td>\n",
" <td>0.703962703963</td>\n",
" <td>0.851981351981</td>\n",
" <td>0.764556962025</td>\n",
" <td>0.296037296037</td>\n",
" <td>0.148018648019</td>\n",
" <td>0.11286407767</td>\n",
" <td>0.824421388667</td>\n",
" <td>0.86920332936979785969</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.49466862411</td>\n",
" <td>731</td>\n",
" <td>126</td>\n",
" <td>93</td>\n",
" <td>303</td>\n",
" <td>0.88713592233</td>\n",
" <td>0.706293706294</td>\n",
" <td>0.852975495916</td>\n",
" <td>0.765151515152</td>\n",
" <td>0.293706293706</td>\n",
" <td>0.147024504084</td>\n",
" <td>0.11286407767</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.86972040452111838192</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.496203693043</td>\n",
" <td>731</td>\n",
" <td>125</td>\n",
" <td>93</td>\n",
" <td>304</td>\n",
" <td>0.88713592233</td>\n",
" <td>0.708624708625</td>\n",
" <td>0.853971962617</td>\n",
" <td>0.765743073048</td>\n",
" <td>0.291375291375</td>\n",
" <td>0.146028037383</td>\n",
" <td>0.11286407767</td>\n",
" <td>0.826017557861</td>\n",
" <td>0.87023809523809523810</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.497656616358</td>\n",
" <td>730</td>\n",
" <td>125</td>\n",
" <td>94</td>\n",
" <td>304</td>\n",
" <td>0.885922330097</td>\n",
" <td>0.708624708625</td>\n",
" <td>0.853801169591</td>\n",
" <td>0.763819095477</td>\n",
" <td>0.291375291375</td>\n",
" <td>0.146198830409</td>\n",
" <td>0.114077669903</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.86956521739130434783</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.499008921623</td>\n",
" <td>729</td>\n",
" <td>125</td>\n",
" <td>95</td>\n",
" <td>304</td>\n",
" <td>0.884708737864</td>\n",
" <td>0.708624708625</td>\n",
" <td>0.853629976581</td>\n",
" <td>0.761904761905</td>\n",
" <td>0.291375291375</td>\n",
" <td>0.146370023419</td>\n",
" <td>0.115291262136</td>\n",
" <td>0.824421388667</td>\n",
" <td>0.86889153754469606675</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.499074380654</td>\n",
" <td>728</td>\n",
" <td>125</td>\n",
" <td>96</td>\n",
" <td>304</td>\n",
" <td>0.883495145631</td>\n",
" <td>0.708624708625</td>\n",
" <td>0.853458382181</td>\n",
" <td>0.76</td>\n",
" <td>0.291375291375</td>\n",
" <td>0.146541617819</td>\n",
" <td>0.116504854369</td>\n",
" <td>0.82362330407</td>\n",
" <td>0.86821705426356589147</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.499863656654</td>\n",
" <td>728</td>\n",
" <td>124</td>\n",
" <td>96</td>\n",
" <td>305</td>\n",
" <td>0.883495145631</td>\n",
" <td>0.710955710956</td>\n",
" <td>0.854460093897</td>\n",
" <td>0.760598503741</td>\n",
" <td>0.289044289044</td>\n",
" <td>0.145539906103</td>\n",
" <td>0.116504854369</td>\n",
" <td>0.824421388667</td>\n",
" <td>0.86873508353221957041</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.500382820041</td>\n",
" <td>728</td>\n",
" <td>123</td>\n",
" <td>96</td>\n",
" <td>306</td>\n",
" <td>0.883495145631</td>\n",
" <td>0.713286713287</td>\n",
" <td>0.855464159812</td>\n",
" <td>0.761194029851</td>\n",
" <td>0.286713286713</td>\n",
" <td>0.144535840188</td>\n",
" <td>0.116504854369</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.86925373134328358209</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.504665625386</td>\n",
" <td>727</td>\n",
" <td>123</td>\n",
" <td>97</td>\n",
" <td>306</td>\n",
" <td>0.882281553398</td>\n",
" <td>0.713286713287</td>\n",
" <td>0.855294117647</td>\n",
" <td>0.759305210918</td>\n",
" <td>0.286713286713</td>\n",
" <td>0.144705882353</td>\n",
" <td>0.117718446602</td>\n",
" <td>0.824421388667</td>\n",
" <td>0.86857825567502986858</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.509291045366</td>\n",
" <td>727</td>\n",
" <td>122</td>\n",
" <td>97</td>\n",
" <td>307</td>\n",
" <td>0.882281553398</td>\n",
" <td>0.715617715618</td>\n",
" <td>0.856301531213</td>\n",
" <td>0.759900990099</td>\n",
" <td>0.284382284382</td>\n",
" <td>0.143698468787</td>\n",
" <td>0.117718446602</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.86909742976688583383</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.516539096279</td>\n",
" <td>727</td>\n",
" <td>121</td>\n",
" <td>97</td>\n",
" <td>308</td>\n",
" <td>0.882281553398</td>\n",
" <td>0.717948717949</td>\n",
" <td>0.857311320755</td>\n",
" <td>0.76049382716</td>\n",
" <td>0.282051282051</td>\n",
" <td>0.142688679245</td>\n",
" <td>0.117718446602</td>\n",
" <td>0.826017557861</td>\n",
" <td>0.86961722488038277512</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.517108140708</td>\n",
" <td>726</td>\n",
" <td>121</td>\n",
" <td>98</td>\n",
" <td>308</td>\n",
" <td>0.881067961165</td>\n",
" <td>0.717948717949</td>\n",
" <td>0.857142857143</td>\n",
" <td>0.758620689655</td>\n",
" <td>0.282051282051</td>\n",
" <td>0.142857142857</td>\n",
" <td>0.118932038835</td>\n",
" <td>0.825219473264</td>\n",
" <td>0.86894075403949730700</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.480496919574974, Decimal('735'), Decimal('130'), Decimal('89'), Decimal('299'), 0.891990291262136, 0.696969696969697, 0.84971098265896, 0.770618556701031, 0.303030303030303, 0.15028901734104, 0.108009708737864, 0.825219473264166, Decimal('0.87033747779751332149')),\n",
" (0.481956373087644, Decimal('734'), Decimal('130'), Decimal('90'), Decimal('299'), 0.890776699029126, 0.696969696969697, 0.849537037037037, 0.768637532133676, 0.303030303030303, 0.150462962962963, 0.109223300970874, 0.824421388667199, Decimal('0.86966824644549763033')),\n",
" (0.484870083560699, Decimal('734'), Decimal('129'), Decimal('90'), Decimal('300'), 0.890776699029126, 0.699300699300699, 0.850521436848204, 0.769230769230769, 0.300699300699301, 0.149478563151796, 0.109223300970874, 0.825219473264166, Decimal('0.87018375815056312982')),\n",
" (0.485451087833237, Decimal('734'), Decimal('128'), Decimal('90'), Decimal('301'), 0.890776699029126, 0.701631701631702, 0.851508120649652, 0.769820971867008, 0.298368298368298, 0.148491879350348, 0.109223300970874, 0.826017557861133, Decimal('0.87069988137603795967')),\n",
" (0.485865382953973, Decimal('733'), Decimal('128'), Decimal('91'), Decimal('301'), 0.889563106796116, 0.701631701631702, 0.851335656213705, 0.767857142857143, 0.298368298368298, 0.148664343786295, 0.110436893203884, 0.825219473264166, Decimal('0.87002967359050445104')),\n",
" (0.487008367123115, Decimal('732'), Decimal('128'), Decimal('92'), Decimal('301'), 0.888349514563107, 0.701631701631702, 0.851162790697674, 0.765903307888041, 0.298368298368298, 0.148837209302326, 0.111650485436893, 0.824421388667199, Decimal('0.86935866983372921615')),\n",
" (0.490888785527338, Decimal('732'), Decimal('127'), Decimal('92'), Decimal('302'), 0.888349514563107, 0.703962703962704, 0.852153667054715, 0.766497461928934, 0.296037296037296, 0.147846332945285, 0.111650485436893, 0.825219473264166, Decimal('0.86987522281639928699')),\n",
" (0.491507015168975, Decimal('731'), Decimal('127'), Decimal('93'), Decimal('302'), 0.887135922330097, 0.703962703962704, 0.851981351981352, 0.764556962025316, 0.296037296037296, 0.148018648018648, 0.112864077669903, 0.824421388667199, Decimal('0.86920332936979785969')),\n",
" (0.494668624110227, Decimal('731'), Decimal('126'), Decimal('93'), Decimal('303'), 0.887135922330097, 0.706293706293706, 0.852975495915986, 0.765151515151515, 0.293706293706294, 0.147024504084014, 0.112864077669903, 0.825219473264166, Decimal('0.86972040452111838192')),\n",
" (0.49620369304347, Decimal('731'), Decimal('125'), Decimal('93'), Decimal('304'), 0.887135922330097, 0.708624708624709, 0.853971962616822, 0.765743073047859, 0.291375291375291, 0.146028037383178, 0.112864077669903, 0.826017557861133, Decimal('0.87023809523809523810')),\n",
" (0.497656616357738, Decimal('730'), Decimal('125'), Decimal('94'), Decimal('304'), 0.885922330097087, 0.708624708624709, 0.853801169590643, 0.763819095477387, 0.291375291375291, 0.146198830409357, 0.114077669902913, 0.825219473264166, Decimal('0.86956521739130434783')),\n",
" (0.499008921622982, Decimal('729'), Decimal('125'), Decimal('95'), Decimal('304'), 0.884708737864078, 0.708624708624709, 0.853629976580796, 0.761904761904762, 0.291375291375291, 0.146370023419204, 0.115291262135922, 0.824421388667199, Decimal('0.86889153754469606675')),\n",
" (0.499074380653872, Decimal('728'), Decimal('125'), Decimal('96'), Decimal('304'), 0.883495145631068, 0.708624708624709, 0.853458382180539, 0.76, 0.291375291375291, 0.146541617819461, 0.116504854368932, 0.823623304070231, Decimal('0.86821705426356589147')),\n",
" (0.499863656653757, Decimal('728'), Decimal('124'), Decimal('96'), Decimal('305'), 0.883495145631068, 0.710955710955711, 0.854460093896714, 0.760598503740648, 0.289044289044289, 0.145539906103286, 0.116504854368932, 0.824421388667199, Decimal('0.86873508353221957041')),\n",
" (0.500382820040981, Decimal('728'), Decimal('123'), Decimal('96'), Decimal('306'), 0.883495145631068, 0.713286713286713, 0.855464159811986, 0.761194029850746, 0.286713286713287, 0.144535840188014, 0.116504854368932, 0.825219473264166, Decimal('0.86925373134328358209')),\n",
" (0.504665625386371, Decimal('727'), Decimal('123'), Decimal('97'), Decimal('306'), 0.882281553398058, 0.713286713286713, 0.855294117647059, 0.759305210918114, 0.286713286713287, 0.144705882352941, 0.117718446601942, 0.824421388667199, Decimal('0.86857825567502986858')),\n",
" (0.509291045365549, Decimal('727'), Decimal('122'), Decimal('97'), Decimal('307'), 0.882281553398058, 0.715617715617716, 0.856301531213192, 0.75990099009901, 0.284382284382284, 0.143698468786808, 0.117718446601942, 0.825219473264166, Decimal('0.86909742976688583383')),\n",
" (0.516539096278762, Decimal('727'), Decimal('121'), Decimal('97'), Decimal('308'), 0.882281553398058, 0.717948717948718, 0.857311320754717, 0.760493827160494, 0.282051282051282, 0.142688679245283, 0.117718446601942, 0.826017557861133, Decimal('0.86961722488038277512')),\n",
" (0.517108140708033, Decimal('726'), Decimal('121'), Decimal('98'), Decimal('308'), 0.881067961165049, 0.717948717948718, 0.857142857142857, 0.758620689655172, 0.282051282051282, 0.142857142857143, 0.118932038834951, 0.825219473264166, Decimal('0.86894075403949730700'))]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT *\n",
"FROM abalone_logreg_test_binary_metrics\n",
"WHERE \n",
" --round(threshold::numeric, 1) = 0.5\n",
" threshold >= 0.48 AND\n",
" threshold <= 0.52\n",
"ORDER BY threshold;"
]
},
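{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the metrics table holds one row per threshold, picking an operating point becomes an ordinary query. The sketch below ranks thresholds by `f1`. Which metric to optimize is an assumption made here for illustration; a different criterion (for example `acc`, or `tpr - fpr`) may suit other applications better."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Thresholds with the highest F1 score; NULLS LAST keeps undefined scores out of the top rows\n",
"SELECT threshold, tpr, fpr, ppv, acc, f1\n",
"FROM abalone_logreg_test_binary_metrics\n",
"ORDER BY f1 DESC NULLS LAST\n",
"LIMIT 5;"
]
},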
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1253 rows affected.\n"
]
}
],
"source": [
"logreg_metrics = %sql SELECT * FROM abalone_logreg_test_binary_metrics ORDER BY threshold;\n",
"logreg_metrics = logreg_metrics.DataFrame();"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"logreg_metrics.plot('fpr', 'tpr',xlim=(0.,1.),ylim=(0.,1.));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"forest\"></a>\n",
"## 5b. Random Forest "
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>forest_train</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_rf_model;\n",
"DROP TABLE IF EXISTS abalone_rf_model_group;\n",
"DROP TABLE IF EXISTS abalone_rf_model_summary;\n",
"SELECT\n",
"madlib.forest_train(\n",
" 'abalone_classif_train', -- training_table_name\n",
" 'abalone_rf_model', -- output_table_name\n",
" 'id', -- id_col_name\n",
" 'mature', -- dependent_variable\n",
" 'length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,sex_f,sex_m', -- list_of_features\n",
" NULL, -- list_of_features_to_exclude\n",
" NULL, -- grouping_columns\n",
" 10 -- number of trees\n",
");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View one of the trees in the forest. You could also view the tree in dot format, though visualizing an individual tree is less relevant for a random forest, which combines the votes of many trees."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>get_tree</th>\n",
" </tr>\n",
" <tr>\n",
" <td>digraph \"Classification tree for abalone_classif_train\" {<br>\"0\" [label=\"length &lt;= 0.495\\n impurity = 0.453016\\n samples = 3011\\n value = [1044 1967]\\n class = 1\", shape=ellipse];<br>\"0\" -&gt; \"1\"[label=\"yes\"];<br>\"0\" -&gt; \"2\"[label=\"no\"];<br>\"1\" [label=\"diameter &lt;= 0.345\\n impurity = 0.430643\\n samples = 1125\\n value = [772 353]\\n class = 0\", shape=ellipse];<br>\"1\" -&gt; \"3\"[label=\"yes\"];<br>\"1\" -&gt; \"4\"[label=\"no\"];<br>\"2\" [label=\"shell_weight &lt;= 0.2335\\n impurity = 0.246842\\n samples = 1886\\n value = [ 272 1614]\\n class = 1\", shape=ellipse];<br>\"2\" -&gt; \"5\"[label=\"yes\"];<br>\"2\" -&gt; \"6\"[label=\"no\"];<br>\"3\" [label=\"sex_m &lt;= 0\\n impurity = 0.293184\\n samples = 751\\n value = [617 134]\\n class = 0\", shape=ellipse];<br>\"3\" -&gt; \"7\"[label=\"yes\"];<br>\"3\" -&gt; \"8\"[label=\"no\"];<br>\"4\" [label=\"shell_weight &lt;= 0.1655\\n impurity = 0.485358\\n samples = 374\\n value = [155 219]\\n class = 1\", shape=ellipse];<br>\"4\" -&gt; \"9\"[label=\"yes\"];<br>\"4\" -&gt; \"10\"[label=\"no\"];<br>\"5\" [label=\"sex_f &lt;= 0\\n impurity = 0.482952\\n samples = 417\\n value = [170 247]\\n class = 1\", shape=ellipse];<br>\"5\" -&gt; \"11\"[label=\"yes\"];<br>\"5\" -&gt; \"12\"[label=\"no\"];<br>\"6\" [label=\"shell_weight &lt;= 0.2725\\n impurity = 0.129228\\n samples = 1469\\n value = [ 102 1367]\\n class = 1\", shape=ellipse];<br>\"6\" -&gt; \"13\"[label=\"yes\"];<br>\"6\" -&gt; \"14\"[label=\"no\"];<br>\"7\" [label=\"sex_f &lt;= 0\\n impurity = 0.209691\\n samples = 605\\n value = [533 72]\\n class = 0\", shape=ellipse];<br>\"7\" -&gt; \"15\"[label=\"yes\"];<br>\"15\" [label=\"0\\n impurity = 0.122069\\n samples = 536\\n value = [501 35]\",shape=box];<br>\"7\" -&gt; \"16\"[label=\"no\"];<br>\"8\" [label=\"length &lt;= 0.29\\n impurity = 0.488647\\n samples = 146\\n value = [84 62]\\n class = 0\", shape=ellipse];<br>\"8\" -&gt; \"17\"[label=\"yes\"];<br>\"17\" [label=\"0\\n 
impurity = 0\\n samples = 28\\n value = [28 0]\",shape=box];<br>\"8\" -&gt; \"18\"[label=\"no\"];<br>\"9\" [label=\"shucked_weight &lt;= 0.222\\n impurity = 0.498485\\n samples = 218\\n value = [115 103]\\n class = 0\", shape=ellipse];<br>\"9\" -&gt; \"19\"[label=\"yes\"];<br>\"9\" -&gt; \"20\"[label=\"no\"];<br>\"10\" [label=\"height &lt;= 0.15\\n impurity = 0.381328\\n samples = 156\\n value = [ 40 116]\\n class = 1\", shape=ellipse];<br>\"10\" -&gt; \"21\"[label=\"yes\"];<br>\"10\" -&gt; \"22\"[label=\"no\"];<br>\"22\" [label=\"0\\n impurity = 0.444444\\n samples = 18\\n value = [12 6]\",shape=box];<br>\"11\" [label=\"height &lt;= 0.125\\n impurity = 0.495403\\n samples = 292\\n value = [132 160]\\n class = 1\", shape=ellipse];<br>\"11\" -&gt; \"23\"[label=\"yes\"];<br>\"11\" -&gt; \"24\"[label=\"no\"];<br>\"12\" [label=\"length &lt;= 0.53\\n impurity = 0.423168\\n samples = 125\\n value = [38 87]\\n class = 1\", shape=ellipse];<br>\"12\" -&gt; \"25\"[label=\"yes\"];<br>\"12\" -&gt; \"26\"[label=\"no\"];<br>\"13\" [label=\"length &lt;= 0.565\\n impurity = 0.289085\\n samples = 291\\n value = [ 51 240]\\n class = 1\", shape=ellipse];<br>\"13\" -&gt; \"27\"[label=\"yes\"];<br>\"13\" -&gt; \"28\"[label=\"no\"];<br>\"14\" [label=\"shell_weight &lt;= 0.3645\\n impurity = 0.0828387\\n samples = 1178\\n value = [ 51 1127]\\n class = 1\", shape=ellipse];<br>\"14\" -&gt; \"29\"[label=\"yes\"];<br>\"14\" -&gt; \"30\"[label=\"no\"];<br>\"16\" [label=\"diameter &lt;= 0.29\\n impurity = 0.497375\\n samples = 69\\n value = [32 37]\\n class = 1\", shape=ellipse];<br>\"16\" -&gt; \"33\"[label=\"yes\"];<br>\"33\" [label=\"0\\n impurity = 0.32\\n samples = 20\\n value = [16 4]\",shape=box];<br>\"16\" -&gt; \"34\"[label=\"no\"];<br>\"18\" [label=\"diameter &lt;= 0.34\\n impurity = 0.498707\\n samples = 118\\n value = [56 62]\\n class = 1\", shape=ellipse];<br>\"18\" -&gt; \"37\"[label=\"yes\"];<br>\"18\" -&gt; \"38\"[label=\"no\"];<br>\"38\" [label=\"0\\n impurity = 0.21875\\n 
samples = 8\\n value = [7 1]\",shape=box];<br>\"19\" [label=\"shell_weight &lt;= 0.125\\n impurity = 0.472126\\n samples = 144\\n value = [55 89]\\n class = 1\", shape=ellipse];<br>\"19\" -&gt; \"39\"[label=\"yes\"];<br>\"39\" [label=\"0\\n impurity = 0.444444\\n samples = 27\\n value = [18 9]\",shape=box];<br>\"19\" -&gt; \"40\"[label=\"no\"];<br>\"20\" [label=\"height &lt;= 0.13\\n impurity = 0.306793\\n samples = 74\\n value = [60 14]\\n class = 0\", shape=ellipse];<br>\"20\" -&gt; \"41\"[label=\"yes\"];<br>\"20\" -&gt; \"42\"[label=\"no\"];<br>\"42\" [label=\"0\\n impurity = 0\\n samples = 13\\n value = [13 0]\",shape=box];<br>\"21\" [label=\"shucked_weight &lt;= 0.3015\\n impurity = 0.323461\\n samples = 138\\n value = [ 28 110]\\n class = 1\", shape=ellipse];<br>\"21\" -&gt; \"43\"[label=\"yes\"];<br>\"21\" -&gt; \"44\"[label=\"no\"];<br>\"44\" [label=\"0\\n impurity = 0.197531\\n samples = 9\\n value = [8 1]\",shape=box];<br>\"23\" [label=\"shucked_weight &lt;= 0.289\\n impurity = 0.444444\\n samples = 72\\n value = [48 24]\\n class = 0\", shape=ellipse];<br>\"23\" -&gt; \"47\"[label=\"yes\"];<br>\"47\" [label=\"0\\n impurity = 0.277778\\n samples = 42\\n value = [35 7]\",shape=box];<br>\"23\" -&gt; \"48\"[label=\"no\"];<br>\"48\" [label=\"1\\n impurity = 0.491111\\n samples = 30\\n value = [13 17]\",shape=box];<br>\"24\" [label=\"diameter &lt;= 0.46\\n impurity = 0.472066\\n samples = 220\\n value = [ 84 136]\\n class = 1\", shape=ellipse];<br>\"24\" -&gt; \"49\"[label=\"yes\"];<br>\"24\" -&gt; \"50\"[label=\"no\"];<br>\"50\" [label=\"0\\n impurity = 0.244898\\n samples = 7\\n value = [6 1]\",shape=box];<br>\"25\" [label=\"shell_weight &lt;= 0.1765\\n impurity = 0.32\\n samples = 65\\n value = [13 52]\\n class = 1\", shape=ellipse];<br>\"25\" -&gt; \"51\"[label=\"yes\"];<br>\"51\" [label=\"0\\n impurity = 0.345679\\n samples = 9\\n value = [7 2]\",shape=box];<br>\"25\" -&gt; \"52\"[label=\"no\"];<br>\"52\" [label=\"1\\n impurity = 0.191327\\n samples = 
56\\n value = [ 6 50]\",shape=box];<br>\"26\" [label=\"shucked_weight &lt;= 0.3875\\n impurity = 0.486111\\n samples = 60\\n value = [25 35]\\n class = 1\", shape=ellipse];<br>\"26\" -&gt; \"53\"[label=\"yes\"];<br>\"26\" -&gt; \"54\"[label=\"no\"];<br>\"54\" [label=\"0\\n impurity = 0.4352\\n samples = 25\\n value = [17 8]\",shape=box];<br>\"27\" [label=\"shucked_weight &lt;= 0.435\\n impurity = 0.183494\\n samples = 137\\n value = [ 14 123]\\n class = 1\", shape=ellipse];<br>\"27\" -&gt; \"55\"[label=\"yes\"];<br>\"27\" -&gt; \"56\"[label=\"no\"];<br>\"56\" [label=\"1\\n impurity = 0.459184\\n samples = 14\\n value = [5 9]\",shape=box];<br>\"28\" [label=\"shell_weight &lt;= 0.265\\n impurity = 0.36507\\n samples = 154\\n value = [ 37 117]\\n class = 1\", shape=ellipse];<br>\"28\" -&gt; \"57\"[label=\"yes\"];<br>\"28\" -&gt; \"58\"[label=\"no\"];<br>\"29\" [label=\"viscera_weight &lt;= 0.289\\n impurity = 0.117514\\n samples = 670\\n value = [ 42 628]\\n class = 1\", shape=ellipse];<br>\"29\" -&gt; \"59\"[label=\"yes\"];<br>\"29\" -&gt; \"60\"[label=\"no\"];<br>\"30\" [label=\"diameter &lt;= 0.48\\n impurity = 0.0348053\\n samples = 508\\n value = [ 9 499]\\n class = 1\", shape=ellipse];<br>\"30\" -&gt; \"61\"[label=\"yes\"];<br>\"30\" -&gt; \"62\"[label=\"no\"];<br>\"34\" [label=\"whole_weight &lt;= 0.3855\\n impurity = 0.439817\\n samples = 49\\n value = [16 33]\\n class = 1\", shape=ellipse];<br>\"34\" -&gt; \"69\"[label=\"yes\"];<br>\"69\" [label=\"1\\n impurity = 0.349636\\n samples = 31\\n value = [ 7 24]\",shape=box];<br>\"34\" -&gt; \"70\"[label=\"no\"];<br>\"70\" [label=\"0\\n impurity = 0.5\\n samples = 18\\n value = [9 9]\",shape=box];<br>\"37\" [label=\"shucked_weight &lt;= 0.194\\n impurity = 0.49405\\n samples = 110\\n value = [49 61]\\n class = 1\", shape=ellipse];<br>\"37\" -&gt; \"75\"[label=\"yes\"];<br>\"37\" -&gt; \"76\"[label=\"no\"];<br>\"76\" [label=\"0\\n impurity = 0.18\\n samples = 10\\n value = [9 1]\",shape=box];<br>\"40\" 
[label=\"length &lt;= 0.485\\n impurity = 0.432464\\n samples = 117\\n value = [37 80]\\n class = 1\", shape=ellipse];<br>\"40\" -&gt; \"81\"[label=\"yes\"];<br>\"40\" -&gt; \"82\"[label=\"no\"];<br>\"82\" [label=\"0\\n impurity = 0\\n samples = 6\\n value = [6 0]\",shape=box];<br>\"41\" [label=\"shell_weight &lt;= 0.1525\\n impurity = 0.353668\\n samples = 61\\n value = [47 14]\\n class = 0\", shape=ellipse];<br>\"41\" -&gt; \"83\"[label=\"yes\"];<br>\"83\" [label=\"0\\n impurity = 0.264514\\n samples = 51\\n value = [43 8]\",shape=box];<br>\"41\" -&gt; \"84\"[label=\"no\"];<br>\"84\" [label=\"1\\n impurity = 0.48\\n samples = 10\\n value = [4 6]\",shape=box];<br>\"43\" [label=\"shell_weight &lt;= 0.203\\n impurity = 0.262003\\n samples = 129\\n value = [ 20 109]\\n class = 1\", shape=ellipse];<br>\"43\" -&gt; \"87\"[label=\"yes\"];<br>\"43\" -&gt; \"88\"[label=\"no\"];<br>\"88\" [label=\"1\\n impurity = 0\\n samples = 26\\n value = [ 0 26]\",shape=box];<br>\"49\" [label=\"shell_weight &lt;= 0.1765\\n impurity = 0.464194\\n samples = 213\\n value = [ 78 135]\\n class = 1\", shape=ellipse];<br>\"49\" -&gt; \"99\"[label=\"yes\"];<br>\"49\" -&gt; \"100\"[label=\"no\"];<br>\"53\" [label=\"shell_weight &lt;= 0.206\\n impurity = 0.352653\\n samples = 35\\n value = [ 8 27]\\n class = 1\", shape=ellipse];<br>\"53\" -&gt; \"107\"[label=\"yes\"];<br>\"107\" [label=\"1\\n impurity = 0.492188\\n samples = 16\\n value = [7 9]\",shape=box];<br>\"53\" -&gt; \"108\"[label=\"no\"];<br>\"108\" [label=\"1\\n impurity = 0.099723\\n samples = 19\\n value = [ 1 18]\",shape=box];<br>\"55\" [label=\"sex_f &lt;= 0\\n impurity = 0.135634\\n samples = 123\\n value = [ 9 114]\\n class = 1\", shape=ellipse];<br>\"55\" -&gt; \"111\"[label=\"yes\"];<br>\"55\" -&gt; \"112\"[label=\"no\"];<br>\"112\" [label=\"1\\n impurity = 0\\n samples = 42\\n value = [ 0 42]\",shape=box];<br>\"57\" [label=\"viscera_weight &lt;= 0.2225\\n impurity = 0.388198\\n samples = 129\\n value = [34 95]\\n class = 1\", 
shape=ellipse];<br>\"57\" -&gt; \"115\"[label=\"yes\"];<br>\"57\" -&gt; \"116\"[label=\"no\"];<br>\"58\" [label=\"viscera_weight &lt;= 0.187\\n impurity = 0.2112\\n samples = 25\\n value = [ 3 22]\\n class = 1\", shape=ellipse];<br>\"58\" -&gt; \"117\"[label=\"yes\"];<br>\"117\" [label=\"1\\n impurity = 0\\n samples = 8\\n value = [0 8]\",shape=box];<br>\"58\" -&gt; \"118\"[label=\"no\"];<br>\"118\" [label=\"1\\n impurity = 0.290657\\n samples = 17\\n value = [ 3 14]\",shape=box];<br>\"59\" [label=\"shucked_weight &lt;= 0.4015\\n impurity = 0.0881506\\n samples = 541\\n value = [ 25 516]\\n class = 1\", shape=ellipse];<br>\"59\" -&gt; \"119\"[label=\"yes\"];<br>\"119\" [label=\"1\\n impurity = 0\\n samples = 166\\n value = [ 0 166]\",shape=box];<br>\"59\" -&gt; \"120\"[label=\"no\"];<br>\"120\" [label=\"1\\n impurity = 0.124444\\n samples = 375\\n value = [ 25 350]\",shape=box];<br>\"60\" [label=\"height &lt;= 0.175\\n impurity = 0.228832\\n samples = 129\\n value = [ 17 112]\\n class = 1\", shape=ellipse];<br>\"60\" -&gt; \"121\"[label=\"yes\"];<br>\"60\" -&gt; \"122\"[label=\"no\"];<br>\"61\" [label=\"viscera_weight &lt;= 0.24\\n impurity = 0.114952\\n samples = 49\\n value = [ 3 46]\\n class = 1\", shape=ellipse];<br>\"61\" -&gt; \"123\"[label=\"yes\"];<br>\"123\" [label=\"1\\n impurity = 0\\n samples = 25\\n value = [ 0 25]\",shape=box];<br>\"61\" -&gt; \"124\"[label=\"no\"];<br>\"62\" [label=\"shell_weight &lt;= 0.44\\n impurity = 0.025802\\n samples = 459\\n value = [ 6 453]\\n class = 1\", shape=ellipse];<br>\"62\" -&gt; \"125\"[label=\"yes\"];<br>\"62\" -&gt; \"126\"[label=\"no\"];<br>\"126\" [label=\"1\\n impurity = 0\\n samples = 195\\n value = [ 0 195]\",shape=box];<br>\"75\" [label=\"height &lt;= 0.115\\n impurity = 0.48\\n samples = 100\\n value = [40 60]\\n class = 1\", shape=ellipse];<br>\"75\" -&gt; \"151\"[label=\"yes\"];<br>\"75\" -&gt; \"152\"[label=\"no\"];<br>\"152\" [label=\"1\\n impurity = 0\\n samples = 20\\n value = [ 0 
20]\",shape=box];<br>\"81\" [label=\"viscera_weight &lt;= 0.11\\n impurity = 0.402565\\n samples = 111\\n value = [31 80]\\n class = 1\", shape=ellipse];<br>\"81\" -&gt; \"163\"[label=\"yes\"];<br>\"81\" -&gt; \"164\"[label=\"no\"];<br>\"87\" [label=\"whole_weight &lt;= 0.5465\\n impurity = 0.312942\\n samples = 103\\n value = [20 83]\\n class = 1\", shape=ellipse];<br>\"87\" -&gt; \"175\"[label=\"yes\"];<br>\"87\" -&gt; \"176\"[label=\"no\"];<br>\"99\" [label=\"length &lt;= 0.5\\n impurity = 0.488522\\n samples = 33\\n value = [19 14]\\n class = 0\", shape=ellipse];<br>\"99\" -&gt; \"199\"[label=\"yes\"];<br>\"199\" [label=\"1\\n impurity = 0.277778\\n samples = 6\\n value = [1 5]\",shape=box];<br>\"99\" -&gt; \"200\"[label=\"no\"];<br>\"100\" [label=\"shucked_weight &lt;= 0.3325\\n impurity = 0.440679\\n samples = 180\\n value = [ 59 121]\\n class = 1\", shape=ellipse];<br>\"100\" -&gt; \"201\"[label=\"yes\"];<br>\"100\" -&gt; \"202\"[label=\"no\"];<br>\"111\" [label=\"shucked_weight &lt;= 0.3445\\n impurity = 0.197531\\n samples = 81\\n value = [ 9 72]\\n class = 1\", shape=ellipse];<br>\"111\" -&gt; \"223\"[label=\"yes\"];<br>\"111\" -&gt; \"224\"[label=\"no\"];<br>\"115\" [label=\"shell_weight &lt;= 0.26\\n impurity = 0.280654\\n samples = 77\\n value = [13 64]\\n class = 1\", shape=ellipse];<br>\"115\" -&gt; \"231\"[label=\"yes\"];<br>\"115\" -&gt; \"232\"[label=\"no\"];<br>\"232\" [label=\"1\\n impurity = 0.493827\\n samples = 9\\n value = [4 5]\",shape=box];<br>\"116\" [label=\"height &lt;= 0.145\\n impurity = 0.481509\\n samples = 52\\n value = [21 31]\\n class = 1\", shape=ellipse];<br>\"116\" -&gt; \"233\"[label=\"yes\"];<br>\"233\" [label=\"0\\n impurity = 0.375\\n samples = 16\\n value = [12 4]\",shape=box];<br>\"116\" -&gt; \"234\"[label=\"no\"];<br>\"234\" [label=\"1\\n impurity = 0.375\\n samples = 36\\n value = [ 9 27]\",shape=box];<br>\"121\" [label=\"whole_weight &lt;= 1.2085\\n impurity = 0.165289\\n samples = 99\\n value = [ 9 90]\\n class = 
1\", shape=ellipse];<br>\"121\" -&gt; \"243\"[label=\"yes\"];<br>\"121\" -&gt; \"244\"[label=\"no\"];<br>\"122\" [label=\"diameter &lt;= 0.5\\n impurity = 0.391111\\n samples = 30\\n value = [ 8 22]\\n class = 1\", shape=ellipse];<br>\"122\" -&gt; \"245\"[label=\"yes\"];<br>\"122\" -&gt; \"246\"[label=\"no\"];<br>\"246\" [label=\"1\\n impurity = 0\\n samples = 8\\n value = [0 8]\",shape=box];<br>\"124\" [label=\"sex_m &lt;= 0\\n impurity = 0.21875\\n samples = 24\\n value = [ 3 21]\\n class = 1\", shape=ellipse];<br>\"124\" -&gt; \"249\"[label=\"yes\"];<br>\"249\" [label=\"1\\n impurity = 0.336735\\n samples = 14\\n value = [ 3 11]\",shape=box];<br>\"124\" -&gt; \"250\"[label=\"no\"];<br>\"250\" [label=\"1\\n impurity = 0\\n samples = 10\\n value = [ 0 10]\",shape=box];<br>\"125\" [label=\"shell_weight &lt;= 0.43\\n impurity = 0.0444215\\n samples = 264\\n value = [ 6 258]\\n class = 1\", shape=ellipse];<br>\"125\" -&gt; \"251\"[label=\"yes\"];<br>\"125\" -&gt; \"252\"[label=\"no\"];<br>\"151\" [label=\"height &lt;= 0.075\\n impurity = 0.5\\n samples = 80\\n value = [40 40]\\n class = 0\", shape=ellipse];<br>\"151\" -&gt; \"303\"[label=\"yes\"];<br>\"303\" [label=\"1\\n impurity = 0.345679\\n samples = 9\\n value = [2 7]\",shape=box];<br>\"151\" -&gt; \"304\"[label=\"no\"];<br>\"304\" [label=\"0\\n impurity = 0.49752\\n samples = 71\\n value = [38 33]\",shape=box];<br>\"163\" [label=\"length &lt;= 0.455\\n impurity = 0.469649\\n samples = 69\\n value = [26 43]\\n class = 1\", shape=ellipse];<br>\"163\" -&gt; \"327\"[label=\"yes\"];<br>\"163\" -&gt; \"328\"[label=\"no\"];<br>\"164\" [label=\"viscera_weight &lt;= 0.1205\\n impurity = 0.209751\\n samples = 42\\n value = [ 5 37]\\n class = 1\", shape=ellipse];<br>\"164\" -&gt; \"329\"[label=\"yes\"];<br>\"329\" [label=\"1\\n impurity = 0.33241\\n samples = 19\\n value = [ 4 15]\",shape=box];<br>\"164\" -&gt; \"330\"[label=\"no\"];<br>\"175\" [label=\"viscera_weight &lt;= 0.1015\\n impurity = 0.19438\\n samples = 55\\n 
value = [ 6 49]\\n class = 1\", shape=ellipse];<br>\"175\" -&gt; \"351\"[label=\"yes\"];<br>\"175\" -&gt; \"352\"[label=\"no\"];<br>\"352\" [label=\"1\\n impurity = 0\\n samples = 33\\n value = [ 0 33]\",shape=box];<br>\"176\" [label=\"whole_weight &lt;= 0.587\\n impurity = 0.413194\\n samples = 48\\n value = [14 34]\\n class = 1\", shape=ellipse];<br>\"176\" -&gt; \"353\"[label=\"yes\"];<br>\"353\" [label=\"1\\n impurity = 0.499055\\n samples = 23\\n value = [11 12]\",shape=box];<br>\"176\" -&gt; \"354\"[label=\"no\"];<br>\"200\" [label=\"diameter &lt;= 0.405\\n impurity = 0.444444\\n samples = 27\\n value = [18 9]\\n class = 0\", shape=ellipse];<br>\"200\" -&gt; \"401\"[label=\"yes\"];<br>\"401\" [label=\"0\\n impurity = 0.375\\n samples = 20\\n value = [15 5]\",shape=box];<br>\"200\" -&gt; \"402\"[label=\"no\"];<br>\"402\" [label=\"1\\n impurity = 0.489796\\n samples = 7\\n value = [3 4]\",shape=box];<br>\"201\" [label=\"shucked_weight &lt;= 0.264\\n impurity = 0.21643\\n samples = 81\\n value = [10 71]\\n class = 1\", shape=ellipse];<br>\"201\" -&gt; \"403\"[label=\"yes\"];<br>\"403\" [label=\"1\\n impurity = 0\\n samples = 35\\n value = [ 0 35]\",shape=box];<br>\"201\" -&gt; \"404\"[label=\"no\"];<br>\"202\" [label=\"whole_weight &lt;= 0.7775\\n impurity = 0.499949\\n samples = 99\\n value = [49 50]\\n class = 1\", shape=ellipse];<br>\"202\" -&gt; \"405\"[label=\"yes\"];<br>\"405\" [label=\"0\\n impurity = 0.424383\\n samples = 36\\n value = [25 11]\",shape=box];<br>\"202\" -&gt; \"406\"[label=\"no\"];<br>\"406\" [label=\"1\\n impurity = 0.471655\\n samples = 63\\n value = [24 39]\",shape=box];<br>\"223\" [label=\"diameter &lt;= 0.435\\n impurity = 0.282481\\n samples = 47\\n value = [ 8 39]\\n class = 1\", shape=ellipse];<br>\"223\" -&gt; \"447\"[label=\"yes\"];<br>\"447\" [label=\"1\\n impurity = 0.192841\\n samples = 37\\n value = [ 4 33]\",shape=box];<br>\"223\" -&gt; \"448\"[label=\"no\"];<br>\"448\" [label=\"1\\n impurity = 0.48\\n samples = 10\\n value 
= [4 6]\",shape=box];<br>\"224\" [label=\"whole_weight &lt;= 0.9\\n impurity = 0.0570934\\n samples = 34\\n value = [ 1 33]\\n class = 1\", shape=ellipse];<br>\"224\" -&gt; \"449\"[label=\"yes\"];<br>\"449\" [label=\"1\\n impurity = 0\\n samples = 26\\n value = [ 0 26]\",shape=box];<br>\"224\" -&gt; \"450\"[label=\"no\"];<br>\"450\" [label=\"1\\n impurity = 0.21875\\n samples = 8\\n value = [1 7]\",shape=box];<br>\"231\" [label=\"viscera_weight &lt;= 0.166\\n impurity = 0.229671\\n samples = 68\\n value = [ 9 59]\\n class = 1\", shape=ellipse];<br>\"231\" -&gt; \"463\"[label=\"yes\"];<br>\"231\" -&gt; \"464\"[label=\"no\"];<br>\"243\" [label=\"sex_m &lt;= 0\\n impurity = 0.352653\\n samples = 35\\n value = [ 8 27]\\n class = 1\", shape=ellipse];<br>\"243\" -&gt; \"487\"[label=\"yes\"];<br>\"487\" [label=\"1\\n impurity = 0.18\\n samples = 20\\n value = [ 2 18]\",shape=box];<br>\"243\" -&gt; \"488\"[label=\"no\"];<br>\"488\" [label=\"1\\n impurity = 0.48\\n samples = 15\\n value = [6 9]\",shape=box];<br>\"244\" [label=\"shell_weight &lt;= 0.339\\n impurity = 0.0307617\\n samples = 64\\n value = [ 1 63]\\n class = 1\", shape=ellipse];<br>\"244\" -&gt; \"489\"[label=\"yes\"];<br>\"489\" [label=\"1\\n impurity = 0.0644444\\n samples = 30\\n value = [ 1 29]\",shape=box];<br>\"244\" -&gt; \"490\"[label=\"no\"];<br>\"490\" [label=\"1\\n impurity = 0\\n samples = 34\\n value = [ 0 34]\",shape=box];<br>\"245\" [label=\"shucked_weight &lt;= 0.572\\n impurity = 0.46281\\n samples = 22\\n value = [ 8 14]\\n class = 1\", shape=ellipse];<br>\"245\" -&gt; \"491\"[label=\"yes\"];<br>\"491\" [label=\"1\\n impurity = 0\\n samples = 9\\n value = [0 9]\",shape=box];<br>\"245\" -&gt; \"492\"[label=\"no\"];<br>\"492\" [label=\"0\\n impurity = 0.473373\\n samples = 13\\n value = [8 5]\",shape=box];<br>\"251\" [label=\"whole_weight &lt;= 1.2085\\n impurity = 0.0331856\\n samples = 237\\n value = [ 4 233]\\n class = 1\", shape=ellipse];<br>\"251\" -&gt; \"503\"[label=\"yes\"];<br>\"503\" 
[label=\"1\\n impurity = 0.15879\\n samples = 23\\n value = [ 2 21]\",shape=box];<br>\"251\" -&gt; \"504\"[label=\"no\"];<br>\"252\" [label=\"viscera_weight &lt;= 0.3175\\n impurity = 0.137174\\n samples = 27\\n value = [ 2 25]\\n class = 1\", shape=ellipse];<br>\"252\" -&gt; \"505\"[label=\"yes\"];<br>\"505\" [label=\"1\\n impurity = 0\\n samples = 11\\n value = [ 0 11]\",shape=box];<br>\"252\" -&gt; \"506\"[label=\"no\"];<br>\"506\" [label=\"1\\n impurity = 0.21875\\n samples = 16\\n value = [ 2 14]\",shape=box];<br>\"327\" [label=\"whole_weight &lt;= 0.444\\n impurity = 0.389273\\n samples = 34\\n value = [ 9 25]\\n class = 1\", shape=ellipse];<br>\"327\" -&gt; \"655\"[label=\"yes\"];<br>\"327\" -&gt; \"656\"[label=\"no\"];<br>\"656\" [label=\"0\\n impurity = 0.473373\\n samples = 13\\n value = [8 5]\",shape=box];<br>\"328\" [label=\"shucked_weight &lt;= 0.2\\n impurity = 0.499592\\n samples = 35\\n value = [17 18]\\n class = 1\", shape=ellipse];<br>\"328\" -&gt; \"657\"[label=\"yes\"];<br>\"328\" -&gt; \"658\"[label=\"no\"];<br>\"658\" [label=\"0\\n impurity = 0.391111\\n samples = 15\\n value = [11 4]\",shape=box];<br>\"330\" [label=\"whole_weight &lt;= 0.514\\n impurity = 0.0831758\\n samples = 23\\n value = [ 1 22]\\n class = 1\", shape=ellipse];<br>\"330\" -&gt; \"661\"[label=\"yes\"];<br>\"661\" [label=\"1\\n impurity = 0\\n samples = 17\\n value = [ 0 17]\",shape=box];<br>\"330\" -&gt; \"662\"[label=\"no\"];<br>\"662\" [label=\"1\\n impurity = 0.277778\\n samples = 6\\n value = [1 5]\",shape=box];<br>\"351\" [label=\"viscera_weight &lt;= 0.087\\n impurity = 0.396694\\n samples = 22\\n value = [ 6 16]\\n class = 1\", shape=ellipse];<br>\"351\" -&gt; \"703\"[label=\"yes\"];<br>\"703\" [label=\"1\\n impurity = 0.142012\\n samples = 13\\n value = [ 1 12]\",shape=box];<br>\"351\" -&gt; \"704\"[label=\"no\"];<br>\"704\" [label=\"0\\n impurity = 0.493827\\n samples = 9\\n value = [5 4]\",shape=box];<br>\"354\" [label=\"diameter &lt;= 0.385\\n impurity = 
0.2112\\n samples = 25\\n value = [ 3 22]\\n class = 1\", shape=ellipse];<br>\"354\" -&gt; \"709\"[label=\"yes\"];<br>\"709\" [label=\"1\\n impurity = 0\\n samples = 17\\n value = [ 0 17]\",shape=box];<br>\"354\" -&gt; \"710\"[label=\"no\"];<br>\"710\" [label=\"1\\n impurity = 0.46875\\n samples = 8\\n value = [3 5]\",shape=box];<br>\"404\" [label=\"length &lt;= 0.54\\n impurity = 0.340265\\n samples = 46\\n value = [10 36]\\n class = 1\", shape=ellipse];<br>\"404\" -&gt; \"809\"[label=\"yes\"];<br>\"404\" -&gt; \"810\"[label=\"no\"];<br>\"810\" [label=\"0\\n impurity = 0.46281\\n samples = 11\\n value = [7 4]\",shape=box];<br>\"463\" [label=\"whole_weight &lt;= 0.833\\n impurity = 0.375\\n samples = 20\\n value = [ 5 15]\\n class = 1\", shape=ellipse];<br>\"463\" -&gt; \"927\"[label=\"yes\"];<br>\"927\" [label=\"1\\n impurity = 0\\n samples = 14\\n value = [ 0 14]\",shape=box];<br>\"463\" -&gt; \"928\"[label=\"no\"];<br>\"928\" [label=\"0\\n impurity = 0.277778\\n samples = 6\\n value = [5 1]\",shape=box];<br>\"464\" [label=\"height &lt;= 0.145\\n impurity = 0.152778\\n samples = 48\\n value = [ 4 44]\\n class = 1\", shape=ellipse];<br>\"464\" -&gt; \"929\"[label=\"yes\"];<br>\"929\" [label=\"1\\n impurity = 0\\n samples = 23\\n value = [ 0 23]\",shape=box];<br>\"464\" -&gt; \"930\"[label=\"no\"];<br>\"930\" [label=\"1\\n impurity = 0.2688\\n samples = 25\\n value = [ 4 21]\",shape=box];<br>\"504\" [label=\"whole_weight &lt;= 1.4095\\n impurity = 0.0185169\\n samples = 214\\n value = [ 2 212]\\n class = 1\", shape=ellipse];<br>\"504\" -&gt; \"1009\"[label=\"yes\"];<br>\"1009\" [label=\"1\\n impurity = 0.0399833\\n samples = 98\\n value = [ 2 96]\",shape=box];<br>\"504\" -&gt; \"1010\"[label=\"no\"];<br>\"1010\" [label=\"1\\n impurity = 0\\n samples = 116\\n value = [ 0 116]\",shape=box];<br>\"655\" [label=\"length &lt;= 0.435\\n impurity = 0.0907029\\n samples = 21\\n value = [ 1 20]\\n class = 1\", shape=ellipse];<br>\"655\" -&gt; 
\"1311\"[label=\"yes\"];<br>\"1311\" [label=\"1\\n impurity = 0.277778\\n samples = 6\\n value = [1 5]\",shape=box];<br>\"655\" -&gt; \"1312\"[label=\"no\"];<br>\"1312\" [label=\"1\\n impurity = 0\\n samples = 15\\n value = [ 0 15]\",shape=box];<br>\"657\" [label=\"shell_weight &lt;= 0.14\\n impurity = 0.42\\n samples = 20\\n value = [ 6 14]\\n class = 1\", shape=ellipse];<br>\"657\" -&gt; \"1315\"[label=\"yes\"];<br>\"1315\" [label=\"0\\n impurity = 0.444444\\n samples = 6\\n value = [4 2]\",shape=box];<br>\"657\" -&gt; \"1316\"[label=\"no\"];<br>\"1316\" [label=\"1\\n impurity = 0.244898\\n samples = 14\\n value = [ 2 12]\",shape=box];<br>\"809\" [label=\"height &lt;= 0.13\\n impurity = 0.156735\\n samples = 35\\n value = [ 3 32]\\n class = 1\", shape=ellipse];<br>\"809\" -&gt; \"1619\"[label=\"yes\"];<br>\"1619\" [label=\"1\\n impurity = 0.375\\n samples = 8\\n value = [2 6]\",shape=box];<br>\"809\" -&gt; \"1620\"[label=\"no\"];<br>\"1620\" [label=\"1\\n impurity = 0.0713306\\n samples = 27\\n value = [ 1 26]\",shape=box];<br><br>} //---end of digraph--------- </td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'digraph \"Classification tree for abalone_classif_train\" {\\n\"0\" [label=\"length <= 0.495\\\\n impurity = 0.453016\\\\n samples = 3011\\\\n value = [1044 196 ... (21658 characters truncated) ... \" -> \"1620\"[label=\"no\"];\\n\"1620\" [label=\"1\\\\n impurity = 0.0713306\\\\n samples = 27\\\\n value = [ 1 26]\",shape=box];\\n\\n} //---end of digraph--------- ',)]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Arguments: forest_model_table_name, gid, sample_id, dot_format, verbose\n",
"SELECT madlib.get_tree('abalone_rf_model', 1, 1, TRUE, TRUE);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the variable importance of each feature, using both the out-of-bag and the impurity-based measures:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"9 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>feature</th>\n",
" <th>oob_var_importance</th>\n",
" <th>impurity_var_importance</th>\n",
" </tr>\n",
" <tr>\n",
" <td>shucked_weight</td>\n",
" <td>12.0025318256</td>\n",
" <td>20.4821989904</td>\n",
" </tr>\n",
" <tr>\n",
" <td>shell_weight</td>\n",
" <td>31.2254609526</td>\n",
" <td>19.8024587933</td>\n",
" </tr>\n",
" <tr>\n",
" <td>length</td>\n",
" <td>23.0067732101</td>\n",
" <td>11.2012793065</td>\n",
" </tr>\n",
" <tr>\n",
" <td>viscera_weight</td>\n",
" <td>3.60975085939</td>\n",
" <td>10.8568589594</td>\n",
" </tr>\n",
" <tr>\n",
" <td>whole_weight</td>\n",
" <td>9.63781565072</td>\n",
" <td>10.4829872168</td>\n",
" </tr>\n",
" <tr>\n",
" <td>diameter</td>\n",
" <td>8.83966394397</td>\n",
" <td>9.84865534216</td>\n",
" </tr>\n",
" <tr>\n",
" <td>height</td>\n",
" <td>0.0</td>\n",
" <td>9.54744115005</td>\n",
" </tr>\n",
" <tr>\n",
" <td>sex_m</td>\n",
" <td>5.96519495962</td>\n",
" <td>4.53823943951</td>\n",
" </tr>\n",
" <tr>\n",
" <td>sex_f</td>\n",
" <td>5.712808598</td>\n",
" <td>3.23988080193</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'shucked_weight', 12.0025318256043, 20.4821989904084),\n",
" (u'shell_weight', 31.2254609525938, 19.8024587932943),\n",
" (u'length', 23.0067732101024, 11.2012793064942),\n",
" (u'viscera_weight', 3.60975085939451, 10.8568589593852),\n",
" (u'whole_weight', 9.63781565071796, 10.4829872167693),\n",
" (u'diameter', 8.83966394396736, 9.84865534216095),\n",
" (u'height', 0.0, 9.54744115004596),\n",
" (u'sex_m', 5.96519495962168, 4.53823943950833),\n",
" (u'sex_f', 5.71280859799804, 3.23988080193335)]"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_rf_importances;\n",
"SELECT madlib.get_var_importance(\n",
" 'abalone_rf_model', -- model_table\n",
" 'abalone_rf_importances' -- output_table\n",
")\n",
";\n",
"SELECT *\n",
"FROM abalone_rf_importances\n",
"ORDER BY impurity_var_importance DESC\n",
";"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Look at some of the predictions"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>forest_predict</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_rf_test_proba;\n",
"SELECT\n",
"madlib.forest_predict(\n",
" 'abalone_rf_model', -- random_forest_model\n",
" 'abalone_classif_test', -- new_data_table\n",
" 'abalone_rf_test_proba', -- output_table\n",
" 'prob' -- type\n",
")\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>estimated_prob_0</th>\n",
" <th>estimated_prob_1</th>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>0.8</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(10, 0.0, 1.0),\n",
" (16, 1.0, 0.0),\n",
" (20, 0.8, 0.2),\n",
" (24, 0.0, 1.0),\n",
" (26, 0.0, 1.0)]"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * \n",
"FROM abalone_rf_test_proba\n",
"LIMIT 5"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1253 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_rf_test_predict_actual;\n",
"CREATE TABLE abalone_rf_test_predict_actual\n",
"AS\n",
"SELECT \n",
" test.id,\n",
" prob.estimated_prob_1,\n",
" prob.estimated_prob_1 >= 0.5 as predicted_class,\n",
" test.mature as actual_class\n",
"FROM \n",
" abalone_rf_test_proba prob\n",
"INNER JOIN\n",
" abalone_classif_test test\n",
"ON\n",
" prob.id = test.id"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"11 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>threshold</th>\n",
" <th>tp</th>\n",
" <th>fp</th>\n",
" <th>fn</th>\n",
" <th>tn</th>\n",
" <th>tpr</th>\n",
" <th>tnr</th>\n",
" <th>ppv</th>\n",
" <th>npv</th>\n",
" <th>fpr</th>\n",
" <th>fdr</th>\n",
" <th>fnr</th>\n",
" <th>acc</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.0</td>\n",
" <td>824</td>\n",
" <td>429</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.657621707901</td>\n",
" <td>None</td>\n",
" <td>1.0</td>\n",
" <td>0.342378292099</td>\n",
" <td>0.0</td>\n",
" <td>0.657621707901</td>\n",
" <td>0.79345209436687530091</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.1</td>\n",
" <td>807</td>\n",
" <td>263</td>\n",
" <td>17</td>\n",
" <td>166</td>\n",
" <td>0.979368932039</td>\n",
" <td>0.386946386946</td>\n",
" <td>0.754205607477</td>\n",
" <td>0.907103825137</td>\n",
" <td>0.613053613054</td>\n",
" <td>0.245794392523</td>\n",
" <td>0.0206310679612</td>\n",
" <td>0.776536312849</td>\n",
" <td>0.85216473072861668427</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.2</td>\n",
" <td>800</td>\n",
" <td>226</td>\n",
" <td>24</td>\n",
" <td>203</td>\n",
" <td>0.970873786408</td>\n",
" <td>0.473193473193</td>\n",
" <td>0.779727095517</td>\n",
" <td>0.894273127753</td>\n",
" <td>0.526806526807</td>\n",
" <td>0.220272904483</td>\n",
" <td>0.0291262135922</td>\n",
" <td>0.800478850758</td>\n",
" <td>0.86486486486486486486</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.3</td>\n",
" <td>787</td>\n",
" <td>198</td>\n",
" <td>37</td>\n",
" <td>231</td>\n",
" <td>0.955097087379</td>\n",
" <td>0.538461538462</td>\n",
" <td>0.798984771574</td>\n",
" <td>0.861940298507</td>\n",
" <td>0.461538461538</td>\n",
" <td>0.201015228426</td>\n",
" <td>0.0449029126214</td>\n",
" <td>0.812450119713</td>\n",
" <td>0.87009397457158651189</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.4</td>\n",
" <td>769</td>\n",
" <td>167</td>\n",
" <td>55</td>\n",
" <td>262</td>\n",
" <td>0.933252427184</td>\n",
" <td>0.610722610723</td>\n",
" <td>0.821581196581</td>\n",
" <td>0.826498422713</td>\n",
" <td>0.389277389277</td>\n",
" <td>0.178418803419</td>\n",
" <td>0.0667475728155</td>\n",
" <td>0.822825219473</td>\n",
" <td>0.87386363636363636364</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.5</td>\n",
" <td>743</td>\n",
" <td>148</td>\n",
" <td>81</td>\n",
" <td>281</td>\n",
" <td>0.901699029126</td>\n",
" <td>0.655011655012</td>\n",
" <td>0.833894500561</td>\n",
" <td>0.776243093923</td>\n",
" <td>0.344988344988</td>\n",
" <td>0.166105499439</td>\n",
" <td>0.0983009708738</td>\n",
" <td>0.817238627294</td>\n",
" <td>0.86647230320699708455</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.6</td>\n",
" <td>719</td>\n",
" <td>119</td>\n",
" <td>105</td>\n",
" <td>310</td>\n",
" <td>0.872572815534</td>\n",
" <td>0.722610722611</td>\n",
" <td>0.85799522673</td>\n",
" <td>0.746987951807</td>\n",
" <td>0.277389277389</td>\n",
" <td>0.14200477327</td>\n",
" <td>0.127427184466</td>\n",
" <td>0.821229050279</td>\n",
" <td>0.86522262334536702768</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.7</td>\n",
" <td>690</td>\n",
" <td>97</td>\n",
" <td>134</td>\n",
" <td>332</td>\n",
" <td>0.837378640777</td>\n",
" <td>0.773892773893</td>\n",
" <td>0.876747141042</td>\n",
" <td>0.712446351931</td>\n",
" <td>0.226107226107</td>\n",
" <td>0.123252858958</td>\n",
" <td>0.162621359223</td>\n",
" <td>0.815642458101</td>\n",
" <td>0.85661080074487895717</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.8</td>\n",
" <td>658</td>\n",
" <td>79</td>\n",
" <td>166</td>\n",
" <td>350</td>\n",
" <td>0.79854368932</td>\n",
" <td>0.815850815851</td>\n",
" <td>0.892808683853</td>\n",
" <td>0.678294573643</td>\n",
" <td>0.184149184149</td>\n",
" <td>0.107191316147</td>\n",
" <td>0.20145631068</td>\n",
" <td>0.804469273743</td>\n",
" <td>0.84304932735426008969</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.9</td>\n",
" <td>619</td>\n",
" <td>60</td>\n",
" <td>205</td>\n",
" <td>369</td>\n",
" <td>0.751213592233</td>\n",
" <td>0.86013986014</td>\n",
" <td>0.911634756996</td>\n",
" <td>0.642857142857</td>\n",
" <td>0.13986013986</td>\n",
" <td>0.0883652430044</td>\n",
" <td>0.248786407767</td>\n",
" <td>0.788507581804</td>\n",
" <td>0.82368596141051230872</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>538</td>\n",
" <td>38</td>\n",
" <td>286</td>\n",
" <td>391</td>\n",
" <td>0.652912621359</td>\n",
" <td>0.911421911422</td>\n",
" <td>0.934027777778</td>\n",
" <td>0.577548005908</td>\n",
" <td>0.0885780885781</td>\n",
" <td>0.0659722222222</td>\n",
" <td>0.347087378641</td>\n",
" <td>0.741420590583</td>\n",
" <td>0.76857142857142857143</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.0, Decimal('824'), Decimal('429'), Decimal('0'), Decimal('0'), 1.0, 0.0, 0.657621707901037, None, 1.0, 0.342378292098962, 0.0, 0.657621707901037, Decimal('0.79345209436687530091')),\n",
" (0.1, Decimal('807'), Decimal('263'), Decimal('17'), Decimal('166'), 0.979368932038835, 0.386946386946387, 0.754205607476636, 0.907103825136612, 0.613053613053613, 0.245794392523364, 0.020631067961165, 0.776536312849162, Decimal('0.85216473072861668427')),\n",
" (0.2, Decimal('800'), Decimal('226'), Decimal('24'), Decimal('203'), 0.970873786407767, 0.473193473193473, 0.779727095516569, 0.894273127753304, 0.526806526806527, 0.220272904483431, 0.029126213592233, 0.80047885075818, Decimal('0.86486486486486486486')),\n",
" (0.3, Decimal('787'), Decimal('198'), Decimal('37'), Decimal('231'), 0.955097087378641, 0.538461538461538, 0.798984771573604, 0.861940298507463, 0.461538461538462, 0.201015228426396, 0.0449029126213592, 0.81245011971269, Decimal('0.87009397457158651189')),\n",
" (0.4, Decimal('769'), Decimal('167'), Decimal('55'), Decimal('262'), 0.933252427184466, 0.610722610722611, 0.821581196581197, 0.826498422712934, 0.389277389277389, 0.178418803418803, 0.066747572815534, 0.822825219473264, Decimal('0.87386363636363636364')),\n",
" (0.5, Decimal('743'), Decimal('148'), Decimal('81'), Decimal('281'), 0.901699029126214, 0.655011655011655, 0.833894500561167, 0.776243093922652, 0.344988344988345, 0.166105499438833, 0.0983009708737864, 0.817238627294493, Decimal('0.86647230320699708455')),\n",
" (0.6, Decimal('719'), Decimal('119'), Decimal('105'), Decimal('310'), 0.872572815533981, 0.722610722610723, 0.85799522673031, 0.746987951807229, 0.277389277389277, 0.14200477326969, 0.127427184466019, 0.82122905027933, Decimal('0.86522262334536702768')),\n",
" (0.7, Decimal('690'), Decimal('97'), Decimal('134'), Decimal('332'), 0.837378640776699, 0.773892773892774, 0.876747141041931, 0.71244635193133, 0.226107226107226, 0.123252858958069, 0.162621359223301, 0.815642458100559, Decimal('0.85661080074487895717')),\n",
" (0.8, Decimal('658'), Decimal('79'), Decimal('166'), Decimal('350'), 0.798543689320388, 0.815850815850816, 0.89280868385346, 0.678294573643411, 0.184149184149184, 0.10719131614654, 0.201456310679612, 0.804469273743017, Decimal('0.84304932735426008969')),\n",
" (0.9, Decimal('619'), Decimal('60'), Decimal('205'), Decimal('369'), 0.75121359223301, 0.86013986013986, 0.911634756995582, 0.642857142857143, 0.13986013986014, 0.0883652430044183, 0.24878640776699, 0.788507581803671, Decimal('0.82368596141051230872')),\n",
" (1.0, Decimal('538'), Decimal('38'), Decimal('286'), Decimal('391'), 0.652912621359223, 0.911421911421911, 0.934027777777778, 0.577548005908419, 0.0885780885780886, 0.0659722222222222, 0.347087378640777, 0.741420590582602, Decimal('0.76857142857142857143'))]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_rf_test_binary_metrics;\n",
"SELECT\n",
"madlib.binary_classifier(\n",
" 'abalone_rf_test_predict_actual', -- table_in\n",
" 'abalone_rf_test_binary_metrics', -- table_out\n",
" 'estimated_prob_1', -- prediction_col\n",
" 'actual_class' -- observation_col\n",
")\n",
";\n",
"SELECT * \n",
"FROM abalone_rf_test_binary_metrics\n",
"ORDER BY threshold\n",
"LIMIT 15\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11 rows affected.\n"
]
}
],
"source": [
"# Collect the false positive and true positive rates at each threshold for the ROC curve\n",
"rf_metrics = %sql SELECT fpr, tpr FROM abalone_rf_test_binary_metrics ORDER BY threshold;\n",
"rf_metrics = rf_metrics.DataFrame()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x133b46650>"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the ROC curve: true positive rate against false positive rate\n",
"rf_metrics.plot('fpr', 'tpr', xlim=(0., 1.), ylim=(0., 1.))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>area_under_roc</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.87614287007490890986052886029827777400595</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('0.87614287007490890986052886029827777400595'),)]"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_rf_test_auc CASCADE;\n",
"SELECT\n",
"madlib.area_under_roc(\n",
" 'abalone_rf_test_predict_actual', -- table_in\n",
" 'abalone_rf_test_auc', -- table_out\n",
" 'estimated_prob_1', -- prediction_col\n",
" 'actual_class' -- observation_col\n",
") as result\n",
";\n",
"\n",
"SELECT * FROM abalone_rf_test_auc ; -- look at the AUC for the random forest model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"regression\"></a>\n",
"# 6. Regression\n",
"\n",
"Previously, our target variable was a binary one that we constructed to represent maturity: an abalone was either mature or not. Now let's predict its age, a continuous value, instead of the binary target."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"linear\"></a>\n",
"## 6a. Linear Regression"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>sex_f</th>\n",
" <th>sex_i</th>\n",
" <th>sex_m</th>\n",
" <th>rings</th>\n",
" <th>age</th>\n",
" <th>mature</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>0.2255</td>\n",
" <td>0.0995</td>\n",
" <td>0.0485</td>\n",
" <td>0.07</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>7</td>\n",
" <td>8.5</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>0.5095</td>\n",
" <td>0.2165</td>\n",
" <td>0.1125</td>\n",
" <td>0.165</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" <td>10.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 0.35, 0.265, 0.09, 0.2255, 0.0995, 0.0485, 0.07, 0, 0, 1, 7, Decimal('8.5'), 0),\n",
" (8, 0.475, 0.37, 0.125, 0.5095, 0.2165, 0.1125, 0.165, 0, 0, 1, 9, Decimal('10.5'), 1)]"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * \n",
"FROM abalone_classif_train\n",
"LIMIT 2"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>r2</th>\n",
" <th>std_err</th>\n",
" <th>t_stats</th>\n",
" <th>p_values</th>\n",
" <th>condition_no</th>\n",
" <th>bp_stats</th>\n",
" <th>bp_p_value</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_missing_rows_skipped</th>\n",
" <th>variance_covariance</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[4.410357699688, 1.27175722855722, 9.64141451253988, 9.33882415610471, 9.5183803866192, -20.1419380713264, -12.1045936650436, 8.74328089164249, 0.817371808454084, 0.876398154242832]</td>\n",
" <td>0.545016598</td>\n",
" <td>[0.317420812747666, 2.14523563920031, 2.63744840317991, 1.71067733937981, 0.892068852840643, 1.00753029831926, 1.58728493291856, 1.39432077797183, 0.123340170109636, 0.11520852987613]</td>\n",
" <td>[13.894355765493, 0.592828687589443, 3.65558412476067, 5.45913828465771, 10.6700064196945, -19.991396888934, -7.62597402268961, 6.27063802660992, 6.6269716324173, 7.60705960908552]</td>\n",
" <td>[1.51375990657579e-42, 0.55334180633369, 0.000261129452464093, 5.18665415881891e-08, 4.23697868795871e-26, 2.00173103596502e-83, 3.25682639467242e-14, 4.12820640447016e-10, 4.06638451488027e-11, 3.7597744950155e-14]</td>\n",
" <td>136.28036679</td>\n",
" <td>345.394268412</td>\n",
" <td>5.91944435666e-69</td>\n",
" <td>2924</td>\n",
" <td>0</td>\n",
" <td>[[0.100755972365389, -0.263855054760177, -0.00609972532228305, -0.0726639103174352, 0.00706375431327825, 0.0402300840383149, 0.0648270560387451, 0.0778928701063708, 0.00271915197088884, 0.000835792028445214], [-0.263855054760176, 4.60203594769516, -5.0528770327474, -0.117336658885753, -0.0025167905987354, -0.13242792154729, -0.287715350508991, 0.103398664665856, 0.0127891124036783, 0.00974470709032043], [-0.00609972532228338, -5.0528770327474, 6.95613407943628, -0.481388937937679, 0.00800125606520873, -0.0374834933330533, 0.119971199934789, -0.476693785352631, -0.0318420322134661, -0.0229396989267095], [-0.0726639103174352, -0.117336658885753, -0.481388937937679, 2.92641695946758, -0.00475647086180627, 0.0307350668680895, -0.105488059797933, -0.232151998596154, -0.018740912150196, -0.0136313984248556], [0.0070637543132782, -0.00251679059873537, 0.00800125606520888, -0.00475647086180627, 0.79578683820842, -0.747049900168081, -0.87799661808084, -0.998391954375627, -0.00417666761345224, -0.0036043942565332], [0.0402300840383148, -0.13242792154729, -0.0374834933330533, 0.0307350668680895, -0.747049900168081, 1.01511730203129, 0.425218416562224, 0.894863457513615, 0.0101263791569862, 0.00369931221671126], [0.0648270560387451, -0.287715350508991, 0.119971199934789, -0.105488059797933, -0.87799661808084, 0.425218416562224, 2.51947345827029, 0.747389592520039, -0.0129601341597446, -0.00808378318425161], [0.0778928701063709, 0.103398664665856, -0.476693785352631, -0.232151998596154, -0.998391954375627, 0.894863457513615, 0.747389592520039, 1.94413043188397, 0.00107506521985202, 0.00309723366682994], [0.00271915197088884, 0.0127891124036783, -0.0318420322134661, -0.018740912150196, -0.00417666761345224, 0.0101263791569862, -0.0129601341597446, 0.00107506521985202, 0.015212797562674, 0.00918132677074068], [0.000835792028445215, 0.00974470709032043, -0.0229396989267095, -0.0136313984248556, -0.0036043942565332, 0.00369931221671126, -0.00808378318425161, 0.00309723366682994, 0.00918132677074068, 0.0132730053562192]]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([4.410357699688, 1.27175722855722, 9.64141451253988, 9.33882415610471, 9.5183803866192, -20.1419380713264, -12.1045936650436, 8.74328089164249, 0.817371808454084, 0.876398154242832], 0.545016598000415, [0.317420812747666, 2.14523563920031, 2.63744840317991, 1.71067733937981, 0.892068852840643, 1.00753029831926, 1.58728493291856, 1.39432077797183, 0.123340170109636, 0.11520852987613], [13.894355765493, 0.592828687589443, 3.65558412476067, 5.45913828465771, 10.6700064196945, -19.991396888934, -7.62597402268961, 6.27063802660992, 6.6269716324173, 7.60705960908552], [1.51375990657579e-42, 0.55334180633369, 0.000261129452464093, 5.18665415881891e-08, 4.23697868795871e-26, 2.00173103596502e-83, 3.25682639467242e-14, 4.12820640447016e-10, 4.06638451488027e-11, 3.7597744950155e-14], 136.280366790125, 345.394268412272, 5.91944435666072e-69, 2924L, 0L, [[0.100755972365389, -0.263855054760177, -0.00609972532228305, -0.0726639103174352, 0.00706375431327825, 0.0402300840383149, 0.0648270560387451, 0.077 ... (1738 characters truncated) ... 5, -0.0136313984248556, -0.0036043942565332, 0.00369931221671126, -0.00808378318425161, 0.00309723366682994, 0.00918132677074068, 0.0132730053562192]])]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_linreg_model;\n",
"DROP TABLE IF EXISTS abalone_linreg_model_summary;\n",
"SELECT madlib.linregr_train(\n",
" 'abalone_classif_train', -- source_table\n",
" 'abalone_linreg_model', -- out_table\n",
" 'age', -- dependent_varname\n",
" 'ARRAY[\n",
" 1,\n",
" length,\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" sex_f,\n",
" sex_m\n",
" ]', -- independent_varname\n",
" NULL, -- grouping_cols\n",
" TRUE -- heteroskedasticity_option\n",
")\n",
";\n",
"SELECT * FROM abalone_linreg_model\n",
"LIMIT 10\n",
";"
]
},
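{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model table stores the coefficients, standard errors, and p-values as arrays, which are hard to read in a single row. As a rough sketch, the query below unnests them next to feature labels. The labels are typed by hand to match the order of the `ARRAY[...]` passed to `linregr_train` above, so this is an assumption you have to keep in sync if the feature list changes. It also relies on parallel `unnest` calls over equal-length arrays lining up row by row."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Labels are hand-written (not read from the model table) and must follow\n",
"-- the same order as the independent variable array used for training.\n",
"SELECT\n",
"    unnest(ARRAY[\n",
"        'intercept', 'length', 'diameter', 'height', 'whole_weight',\n",
"        'shucked_weight', 'viscera_weight', 'shell_weight', 'sex_f', 'sex_m'\n",
"    ]) AS feature,\n",
"    unnest(coef) AS coefficient,\n",
"    unnest(std_err) AS std_err,\n",
"    unnest(p_values) AS p_value\n",
"FROM abalone_linreg_model\n",
";"
]
},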
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Look at the predictions from the Linear Regression Model"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1253 rows affected.\n",
"5 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>predicted_age</th>\n",
" <th>actual_age</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3918</td>\n",
" <td>14.0356025263</td>\n",
" <td>19.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>684</td>\n",
" <td>11.9578243725</td>\n",
" <td>11.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1664</td>\n",
" <td>10.2519022088</td>\n",
" <td>10.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3936</td>\n",
" <td>12.7607562191</td>\n",
" <td>14.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>654</td>\n",
" <td>9.80860239621</td>\n",
" <td>11.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3918, 14.0356025263472, Decimal('19.5')),\n",
" (684, 11.9578243724615, Decimal('11.5')),\n",
" (1664, 10.2519022087522, Decimal('10.5')),\n",
" (3936, 12.760756219147, Decimal('14.5')),\n",
" (654, 9.80860239621498, Decimal('11.5'))]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_linreg_test_predict;\n",
"CREATE TABLE abalone_linreg_test_predict\n",
"AS\n",
"SELECT \n",
" test.id,\n",
" madlib.linregr_predict(\n",
" coef, \n",
" ARRAY[\n",
" 1,\n",
" length,\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" sex_f,\n",
" sex_m\n",
" ] \n",
" ) as predicted_age,\n",
" test.age as actual_age\n",
"FROM abalone_classif_test test, abalone_linreg_model model\n",
";\n",
"SELECT * FROM abalone_linreg_test_predict\n",
"LIMIT 5\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mean_squared_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>4.70025372012</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(4.70025372011683,)]"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_linreg_test_predict_mse;\n",
"SELECT madlib.mean_squared_error(\n",
" 'abalone_linreg_test_predict', -- table_in\n",
" 'abalone_linreg_test_predict_mse', -- table_out\n",
" 'predicted_age', -- prediction_col\n",
" 'actual_age' -- observed_col\n",
") as result\n",
";\n",
"SELECT * FROM abalone_linreg_test_predict_mse;"
]
},
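{
"cell_type": "markdown",
"metadata": {},
"source": [
"MSE is expressed in squared years, so its square root (RMSE) is often easier to interpret because it is in the same units as age:\n",
"\n",
"$$\\mathrm{MSE} = \\frac{1}{n} \\sum_{i=1}^{n} \\left( \\hat{y}_i - y_i \\right)^2, \\qquad \\mathrm{RMSE} = \\sqrt{\\mathrm{MSE}}$$\n",
"\n",
"A minimal sketch, assuming the MSE output table from the previous cell still exists with its single `mean_squared_error` column (the alias `root_mean_squared_error` is ours). For the MSE of about 4.70 reported above, the RMSE should come out near 2.17 years."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Take the square root of the stored MSE to get an error in years.\n",
"SELECT\n",
"    mean_squared_error,\n",
"    sqrt(mean_squared_error) AS root_mean_squared_error\n",
"FROM abalone_linreg_test_predict_mse\n",
";"
]
},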
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"elastic\"></a>\n",
"## 6b. Elastic Net Regression\n",
"\n",
"Elastic net regression is linear regression with penalties assigned to the size of the coefficients.\n",
"\n",
"Note that MADlib's elastic net automatically fits an intercept, so you shouldn't include an explicit intercept column of 1's in your independent variable array."
]
},
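{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of what the coefficient penalties look like for the Gaussian family, the elastic net objective can be written as below. The notation ($N$ training rows, coefficient vector $w$, intercept $w_0$) is ours, and the scaling constants follow the convention we understand the MADlib documentation to use, so treat this as illustrative rather than as the exact internal loss:\n",
"\n",
"$$\\min_{w,\\,w_0} \\; \\frac{1}{2N} \\sum_{i=1}^{N} \\left( y_i - w_0 - w^\\top x_i \\right)^2 + \\lambda \\left[ \\alpha \\lVert w \\rVert_1 + \\frac{1-\\alpha}{2} \\lVert w \\rVert_2^2 \\right]$$\n",
"\n",
"Here `lambda_value` ($\\lambda$) sets the overall penalty strength and `alpha` ($\\alpha$) mixes the two penalties: $\\alpha = 1$ corresponds to the lasso and $\\alpha = 0$ to ridge regression. The intercept $w_0$ is not penalized."
]
},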
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>elastic_net_train</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_elasticnet_model CASCADE;\n",
"DROP TABLE IF EXISTS abalone_elasticnet_model_summary CASCADE;\n",
"SELECT madlib.elastic_net_train( \n",
" 'abalone_classif_train', -- tbl_source\n",
" 'abalone_elasticnet_model', -- tbl_result\n",
" 'age', -- col_dep_var\n",
" 'ARRAY[\n",
" length,\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" sex_f,\n",
" sex_m\n",
" ]', -- col_ind_var\n",
" 'gaussian', -- regress_family\n",
" 0.5, -- alpha\n",
" 0.5, -- lambda_value\n",
" TRUE -- standardize\n",
" --, -- grouping_col\n",
" --, -- optimizer\n",
" --, -- optimizer_params\n",
" --, -- excluded\n",
" --, -- max_iter\n",
" -- tolerance\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>family</th>\n",
" <th>features</th>\n",
" <th>features_selected</th>\n",
" <th>coef_nonzero</th>\n",
" <th>coef_all</th>\n",
" <th>intercept</th>\n",
" <th>log_likelihood</th>\n",
" <th>standardize</th>\n",
" <th>iteration_run</th>\n",
" </tr>\n",
" <tr>\n",
" <td>gaussian</td>\n",
" <td>[u'[1]', u'[2]', u'[3]', u'[4]', u'[5]', u'[6]', u'[7]', u'[8]', u'[9]']</td>\n",
" <td>[u'[1]', u'[2]', u'[3]', u'[7]', u'[8]']</td>\n",
" <td>[1.3220343857, 3.14984818966, 8.52839367593, 6.31126111497, 0.120956346376]</td>\n",
" <td>[1.3220343857, 3.14984818966, 8.52839367593, 0.0, 0.0, 0.0, 6.31126111497, 0.120956346376, 0.0]</td>\n",
" <td>6.74458430257</td>\n",
" <td>-3.83846837185</td>\n",
" <td>True</td>\n",
" <td>193</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'gaussian', [u'[1]', u'[2]', u'[3]', u'[4]', u'[5]', u'[6]', u'[7]', u'[8]', u'[9]'], [u'[1]', u'[2]', u'[3]', u'[7]', u'[8]'], [1.3220343857, 3.14984818966, 8.52839367593, 6.31126111497, 0.120956346376], [1.3220343857, 3.14984818966, 8.52839367593, 0.0, 0.0, 0.0, 6.31126111497, 0.120956346376, 0.0], 6.74458430257, -3.83846837185, True, 193)]"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM abalone_elasticnet_model"
]
},
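{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `features` column only shows positional indices (`[1]` to `[9]`) because the independent variables were passed as an array expression. A hedged sketch for attaching names is to unnest `coef_all` next to hand-written labels in the same order as the training array; the labels are our own and must be kept in sync with the `ARRAY[...]` above. A coefficient of 0 means the penalty dropped that feature. In the run shown above, the non-zero positions 1, 2, 3, 7, and 8 correspond to `length`, `diameter`, `height`, `shell_weight`, and `sex_f`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Labels are hand-written and follow the order of the training array;\n",
"-- there is no intercept entry because elastic net stores it separately.\n",
"SELECT\n",
"    unnest(ARRAY[\n",
"        'length', 'diameter', 'height', 'whole_weight', 'shucked_weight',\n",
"        'viscera_weight', 'shell_weight', 'sex_f', 'sex_m'\n",
"    ]) AS feature,\n",
"    unnest(coef_all) AS coefficient\n",
"FROM abalone_elasticnet_model\n",
";"
]
},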
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>method</th>\n",
" <th>source_table</th>\n",
" <th>out_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>family</th>\n",
" <th>alpha</th>\n",
" <th>lambda_value</th>\n",
" <th>grouping_col</th>\n",
" <th>num_all_groups</th>\n",
" <th>num_failed_groups</th>\n",
" </tr>\n",
" <tr>\n",
" <td>elastic_net</td>\n",
" <td>abalone_classif_train</td>\n",
" <td>abalone_elasticnet_model</td>\n",
" <td>age</td>\n",
" <td>ARRAY[<br> length,<br> diameter,<br> height,<br> whole_weight,<br> shucked_weight,<br> viscera_weight,<br> shell_weight,<br> sex_f,<br> sex_m<br> ]</td>\n",
" <td>gaussian</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>NULL</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'elastic_net', u'abalone_classif_train', u'abalone_elasticnet_model', u'age', u'ARRAY[\\n length,\\n diameter,\\n height,\\n whole_weight,\\n shucked_weight,\\n viscera_weight,\\n shell_weight,\\n sex_f,\\n sex_m\\n ]', u'gaussian', 0.5, 0.5, u'NULL', 1, 0)]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM abalone_elasticnet_model_summary"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1253 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_elasticnet_test_predict;\n",
"CREATE TABLE abalone_elasticnet_test_predict\n",
"AS\n",
"SELECT \n",
" test.id,\n",
" madlib.elastic_net_gaussian_predict(\n",
" model.coef_all, \n",
" model.intercept,\n",
" ARRAY[\n",
" length,\n",
" diameter,\n",
" height,\n",
" whole_weight,\n",
" shucked_weight,\n",
" viscera_weight,\n",
" shell_weight,\n",
" sex_f,\n",
" sex_m\n",
" ] \n",
" ) as predicted_age,\n",
" test.age as actual_age\n",
"FROM abalone_classif_test test, abalone_elasticnet_model model\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>predicted_age</th>\n",
" <th>actual_age</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3941</td>\n",
" <td>12.5704131046</td>\n",
" <td>16.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3441</td>\n",
" <td>10.3744669188</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>12.9497348984</td>\n",
" <td>14.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>849</td>\n",
" <td>11.982881522</td>\n",
" <td>11.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>224</td>\n",
" <td>10.029998676</td>\n",
" <td>11.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>651</td>\n",
" <td>9.10540972598</td>\n",
" <td>7.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1907</td>\n",
" <td>11.6905236073</td>\n",
" <td>10.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2603</td>\n",
" <td>12.8194751477</td>\n",
" <td>10.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>592</td>\n",
" <td>11.5163675347</td>\n",
" <td>19.5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>261</td>\n",
" <td>11.7751231628</td>\n",
" <td>13.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3941, 12.5704131046376, Decimal('16.5')),\n",
" (3441, 10.374466918844, Decimal('8.5')),\n",
" (92, 12.9497348983795, Decimal('14.5')),\n",
" (849, 11.9828815219954, Decimal('11.5')),\n",
" (224, 10.0299986760329, Decimal('11.5')),\n",
" (651, 9.1054097259781, Decimal('7.5')),\n",
" (1907, 11.6905236073399, Decimal('10.5')),\n",
" (2603, 12.8194751476722, Decimal('10.5')),\n",
" (592, 11.5163675346797, Decimal('19.5')),\n",
" (261, 11.7751231627786, Decimal('13.5'))]"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM abalone_elasticnet_test_predict LIMIT 10"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mean_squared_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>6.28606344646</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(6.28606344645616,)]"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_elasticnet_test_predict_mse;\n",
"SELECT madlib.mean_squared_error(\n",
" 'abalone_elasticnet_test_predict', -- table_in\n",
" 'abalone_elasticnet_test_predict_mse', -- table_out\n",
" 'predicted_age', -- prediction_col\n",
" 'actual_age' -- observed_col\n",
") \n",
";\n",
"SELECT * FROM abalone_elasticnet_test_predict_mse\n",
";"
]
}
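,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compare the two regression models side by side, the two MSE tables can be stacked with `UNION ALL`. This is only a convenience sketch: the `model` labels are ours, and it assumes both MSE cells above have been run. In the run shown in this notebook, plain linear regression has the lower test MSE (about 4.70 versus 6.29), which may indicate that `lambda_value = 0.5` penalizes too heavily for this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Stack the test-set MSE of both models; the model labels are illustrative.\n",
"SELECT 'linear regression'::text AS model, mean_squared_error\n",
"FROM abalone_linreg_test_predict_mse\n",
"UNION ALL\n",
"SELECT 'elastic net'::text AS model, mean_squared_error\n",
"FROM abalone_elasticnet_test_predict_mse\n",
";"
]
}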
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.16"
},
"toc": {
"base_numbering": 1,
"nav_menu": {
"height": "235.994px",
"width": "273.991px"
},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "208px"
},
"toc_section_display": true,
"toc_window_display": true
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}