| /* ----------------------------------------------------------------------- *//** |
| * |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| * |
| *//* ----------------------------------------------------------------------- */ |
| |
| /* ----------------------------------------------------------------------------- |
| * Test Multinomial Logistic Regression. |
| * -------------------------------------------------------------------------- */ |
| |
| /* |
| * The values given by the multinomial logistic regression were cross checked |
| * with the Matlab command mnrfit, which is documented at |
| * http://www.mathworks.com/help/toolbox/stats/mnrfit.html |
| * |
* One important detail of the mnrfit command is that, due to a difference in convention,
* its coefficient estimates are the negatives of ours. Our
| * convention is chosen to match the convention of the binary |
| * logistic regression implementation in madlib. |
| * |
* For completeness, the MATLAB code needed to check the answers for the
* multinom_test example below is included here. The code assumes that the data
* has been exported to a CSV file whose columns are feat1, feat2, cat, in that
* order. Leave out column g: the code treats every column except the last as a
* covariate and the last column as the category. The coefficients will be in
* the 'B' variable.
| * |
| * BEGIN CODE |
| * |
data = csvread(csvFilename);   % csvFilename: path to the exported CSV file
N = size(data, 1);             % Number of records
J = size(data, 2)-1;           % Number of covariates

% Categories are integer encoded as {0,1,...,K-1}; mnrfit expects {1,...,K}
int_y = 1+data(:,end);         % Categories
x = data(:,1:end-1);           % Independent variables

% mnrfit uses the last category as the reference (pivot) category
[B,dev,stats] = mnrfit(x,int_y)
| * |
| * END CODE |
| */ |
| |
-- Fixture for the multinomial and ordinal regression tests: two integer
-- features, an integer-coded category (the rows below use 0, 1 and 2) and
-- a one-character label column g (presumably for grouped runs -- confirm).
DROP TABLE IF EXISTS multinom_test;
CREATE TABLE multinom_test (
    feat1  INTEGER,
    feat2  INTEGER,
    cat    INTEGER,
    g      CHAR(1)
);
| |
-- 83 observations, all labelled g = 'A'; cat takes the values 0, 1 and 2.
INSERT INTO multinom_test (feat1, feat2, cat, g)
VALUES
    (1, 35, 1, 'A'),
    (2, 33, 0, 'A'),
    (3, 39, 1, 'A'),
    (1, 37, 1, 'A'),
    (2, 31, 1, 'A'),
    (3, 36, 0, 'A'),
    (2, 36, 1, 'A'),
    (2, 31, 1, 'A'),
    (2, 41, 1, 'A'),
    (2, 37, 1, 'A'),
    (1, 44, 1, 'A'),
    (3, 33, 2, 'A'),
    (1, 31, 1, 'A'),
    (2, 44, 1, 'A'),
    (1, 35, 1, 'A'),
    (1, 44, 0, 'A'),
    (1, 46, 0, 'A'),
    (2, 46, 1, 'A'),
    (2, 46, 2, 'A'),
    (3, 49, 1, 'A'),
    (2, 39, 0, 'A'),
    (2, 44, 1, 'A'),
    (1, 47, 1, 'A'),
    (1, 44, 1, 'A'),
    (1, 37, 2, 'A'),
    (3, 38, 2, 'A'),
    (1, 49, 0, 'A'),
    (2, 44, 0, 'A'),
    (1, 41, 2, 'A'),
    (1, 50, 2, 'A'),
    (2, 44, 0, 'A'),
    (1, 39, 1, 'A'),
    (1, 40, 2, 'A'),
    (1, 46, 2, 'A'),
    (2, 41, 1, 'A'),
    (2, 39, 1, 'A'),
    (2, 33, 1, 'A'),
    (3, 59, 2, 'A'),
    (1, 41, 0, 'A'),
    (2, 47, 2, 'A'),
    (2, 31, 0, 'A'),
    (3, 42, 2, 'A'),
    (1, 55, 2, 'A'),
    (3, 40, 1, 'A'),
    (1, 44, 2, 'A'),
    (1, 54, 1, 'A'),
    (2, 46, 1, 'A'),
    (1, 54, 0, 'A'),
    (2, 42, 1, 'A'),
    (2, 49, 2, 'A'),
    (2, 41, 2, 'A'),
    (2, 41, 1, 'A'),
    (1, 44, 0, 'A'),
    (1, 57, 2, 'A'),
    (2, 52, 2, 'A'),
    (1, 49, 0, 'A'),
    (3, 41, 2, 'A'),
    (3, 57, 0, 'A'),
    (1, 62, 1, 'A'),
    (3, 33, 0, 'A'),
    (2, 54, 1, 'A'),
    (2, 40, 2, 'A'),
    (3, 52, 2, 'A'),
    (2, 57, 1, 'A'),
    (2, 49, 1, 'A'),
    (2, 46, 1, 'A'),
    (1, 57, 0, 'A'),
    (2, 49, 2, 'A'),
    (2, 52, 2, 'A'),
    (2, 53, 0, 'A'),
    (3, 54, 2, 'A'),
    (2, 57, 2, 'A'),
    (3, 41, 2, 'A'),
    (1, 52, 0, 'A'),
    (2, 57, 1, 'A'),
    (1, 54, 0, 'A'),
    (2, 52, 1, 'A'),
    (2, 52, 0, 'A'),
    (2, 44, 0, 'A'),
    (2, 46, 2, 'A'),
    (1, 49, 1, 'A'),
    (2, 54, 2, 'A'),
    (3, 52, 2, 'A');
| |
-- Fit a multinomial logit model of cat on feat1, feat2 and an intercept,
-- using category 2 (the highest category in the data) as the reference.
-- Any output left over from a previous run is dropped first, as is done
-- before every glm() call in this file, so the test can be re-run in the
-- same schema.
DROP TABLE IF EXISTS mglm_test, mglm_test_summary;
SELECT multinom(
    'multinom_test',            -- source table
    'mglm_test',                -- output model table
    'cat',                      -- dependent variable
    'ARRAY[feat1, feat2, 1]',   -- independent variables (1 = intercept)
    '2',                        -- reference category
    'logit',                    -- link function
    NULL,                       -- grouping columns (none)
    NULL,                       -- optimizer parameters (defaults)
    TRUE);                      -- verbose
| |
-- Fit an ordinal logit model on the same data with categories ordered
-- 0 < 1 < 2. Unlike the multinom() call, no explicit intercept column is
-- passed (presumably the per-category thresholds play that role -- confirm).
-- Stale output from a previous run is dropped first so the test is
-- re-runnable.
DROP TABLE IF EXISTS ordinal_out, ordinal_out_summary;
SELECT ordinal(
    'multinom_test',            -- source table
    'ordinal_out',              -- output model table
    'cat',                      -- dependent variable
    'ARRAY[feat1, feat2]',      -- independent variables
    '0<1<2',                    -- category order
    'logit'                     -- link function
);
| |
| |
-- Abalone measurement fixture; rings is the response variable in the
-- glm() family/link tests that follow.
DROP TABLE IF EXISTS abalone CASCADE;
CREATE TABLE abalone (
    id        INTEGER,
    sex       TEXT,
    length    DOUBLE PRECISION,
    diameter  DOUBLE PRECISION,
    height    DOUBLE PRECISION,
    whole     DOUBLE PRECISION,
    shucked   DOUBLE PRECISION,
    viscera   DOUBLE PRECISION,
    shell     DOUBLE PRECISION,
    rings     INTEGER
);
| |
-- Abalone fixture rows. The column list is spelled out so that a reordered
-- or extended table definition makes this statement fail loudly, instead of
-- silently shifting values into the wrong columns.
INSERT INTO abalone (id, sex, length, diameter, height, whole, shucked, viscera, shell, rings) VALUES
(3151, 'F', 0.655000000000000027, 0.505000000000000004, 0.165000000000000008, 1.36699999999999999, 0.583500000000000019, 0.351499999999999979, 0.396000000000000019, 10),
(2026, 'F', 0.550000000000000044, 0.469999999999999973, 0.149999999999999994, 0.920499999999999985, 0.381000000000000005, 0.243499999999999994, 0.267500000000000016, 10),
(3751, 'I', 0.434999999999999998, 0.375, 0.110000000000000001, 0.41549999999999998, 0.170000000000000012, 0.0759999999999999981, 0.14499999999999999, 8),
(720, 'I', 0.149999999999999994, 0.100000000000000006, 0.0250000000000000014, 0.0149999999999999994, 0.00449999999999999966, 0.00400000000000000008, 0.0050000000000000001, 2),
(1635, 'F', 0.574999999999999956, 0.469999999999999973, 0.154999999999999999, 1.1160000000000001, 0.509000000000000008, 0.237999999999999989, 0.340000000000000024, 10),
(2648, 'I', 0.5, 0.390000000000000013, 0.125, 0.582999999999999963, 0.293999999999999984, 0.132000000000000006, 0.160500000000000004, 8),
(1796, 'F', 0.57999999999999996, 0.429999999999999993, 0.170000000000000012, 1.47999999999999998, 0.65349999999999997, 0.32400000000000001, 0.41549999999999998, 10),
(209, 'F', 0.525000000000000022, 0.41499999999999998, 0.170000000000000012, 0.832500000000000018, 0.275500000000000023, 0.168500000000000011, 0.309999999999999998, 13),
(1451, 'I', 0.455000000000000016, 0.33500000000000002, 0.135000000000000009, 0.501000000000000001, 0.274000000000000021, 0.0995000000000000051, 0.106499999999999997, 7),
(1108, 'I', 0.510000000000000009, 0.380000000000000004, 0.115000000000000005, 0.515499999999999958, 0.214999999999999997, 0.113500000000000004, 0.166000000000000009, 8),
(3675, 'F', 0.594999999999999973, 0.450000000000000011, 0.165000000000000008, 1.08099999999999996, 0.489999999999999991, 0.252500000000000002, 0.279000000000000026, 12),
(2108, 'F', 0.675000000000000044, 0.550000000000000044, 0.179999999999999993, 1.68849999999999989, 0.562000000000000055, 0.370499999999999996, 0.599999999999999978, 15),
(3312, 'F', 0.479999999999999982, 0.380000000000000004, 0.135000000000000009, 0.507000000000000006, 0.191500000000000004, 0.13650000000000001, 0.154999999999999999, 12),
(882, 'M', 0.655000000000000027, 0.520000000000000018, 0.165000000000000008, 1.40949999999999998, 0.585999999999999965, 0.290999999999999981, 0.405000000000000027, 9),
(3402, 'M', 0.479999999999999982, 0.395000000000000018, 0.149999999999999994, 0.681499999999999995, 0.214499999999999996, 0.140500000000000014, 0.2495, 18),
(829, 'I', 0.409999999999999976, 0.325000000000000011, 0.100000000000000006, 0.394000000000000017, 0.20799999999999999, 0.0655000000000000027, 0.105999999999999997, 6),
(1305, 'M', 0.535000000000000031, 0.434999999999999998, 0.149999999999999994, 0.716999999999999971, 0.347499999999999976, 0.14449999999999999, 0.194000000000000006, 9),
(3613, 'M', 0.599999999999999978, 0.46000000000000002, 0.179999999999999993, 1.1399999999999999, 0.422999999999999987, 0.257500000000000007, 0.364999999999999991, 10),
(1068, 'I', 0.340000000000000024, 0.265000000000000013, 0.0800000000000000017, 0.201500000000000012, 0.0899999999999999967, 0.0475000000000000006, 0.0550000000000000003, 5),
(2446, 'M', 0.5, 0.380000000000000004, 0.135000000000000009, 0.583500000000000019, 0.22950000000000001, 0.126500000000000001, 0.179999999999999993, 12),
(1393, 'M', 0.635000000000000009, 0.474999999999999978, 0.170000000000000012, 1.19350000000000001, 0.520499999999999963, 0.269500000000000017, 0.366499999999999992, 10),
(359, 'M', 0.744999999999999996, 0.584999999999999964, 0.214999999999999997, 2.49900000000000011, 0.92649999999999999, 0.471999999999999975, 0.699999999999999956, 17),
(549, 'F', 0.564999999999999947, 0.450000000000000011, 0.160000000000000003, 0.79500000000000004, 0.360499999999999987, 0.155499999999999999, 0.23000000000000001, 12),
(1154, 'F', 0.599999999999999978, 0.474999999999999978, 0.160000000000000003, 1.02649999999999997, 0.484999999999999987, 0.2495, 0.256500000000000006, 9),
(1790, 'F', 0.54500000000000004, 0.385000000000000009, 0.149999999999999994, 1.11850000000000005, 0.542499999999999982, 0.244499999999999995, 0.284499999999999975, 9),
(3703, 'F', 0.665000000000000036, 0.540000000000000036, 0.195000000000000007, 1.76400000000000001, 0.850500000000000034, 0.361499999999999988, 0.469999999999999973, 11),
(1962, 'F', 0.655000000000000027, 0.515000000000000013, 0.179999999999999993, 1.41199999999999992, 0.619500000000000051, 0.248499999999999999, 0.496999999999999997, 11),
(1665, 'I', 0.604999999999999982, 0.469999999999999973, 0.14499999999999999, 0.802499999999999991, 0.379000000000000004, 0.226500000000000007, 0.220000000000000001, 9),
(635, 'M', 0.359999999999999987, 0.294999999999999984, 0.100000000000000006, 0.210499999999999993, 0.0660000000000000031, 0.0524999999999999981, 0.0749999999999999972, 9),
(3901, 'M', 0.445000000000000007, 0.344999999999999973, 0.140000000000000013, 0.475999999999999979, 0.205499999999999988, 0.101500000000000007, 0.108499999999999999, 15),
(2734, 'I', 0.41499999999999998, 0.33500000000000002, 0.100000000000000006, 0.357999999999999985, 0.169000000000000011, 0.067000000000000004, 0.104999999999999996, 7),
(3856, 'M', 0.409999999999999976, 0.33500000000000002, 0.115000000000000005, 0.440500000000000003, 0.190000000000000002, 0.0850000000000000061, 0.135000000000000009, 8),
(827, 'I', 0.395000000000000018, 0.28999999999999998, 0.0950000000000000011, 0.303999999999999992, 0.127000000000000002, 0.0840000000000000052, 0.076999999999999999, 6),
(3381, 'I', 0.190000000000000002, 0.130000000000000004, 0.0449999999999999983, 0.0264999999999999993, 0.00899999999999999932, 0.0050000000000000001, 0.00899999999999999932, 5),
(3972, 'I', 0.400000000000000022, 0.294999999999999984, 0.0950000000000000011, 0.252000000000000002, 0.110500000000000001, 0.0575000000000000025, 0.0660000000000000031, 6),
(1155, 'M', 0.599999999999999978, 0.455000000000000016, 0.170000000000000012, 1.1915, 0.695999999999999952, 0.239499999999999991, 0.239999999999999991, 8),
(3467, 'M', 0.640000000000000013, 0.5, 0.170000000000000012, 1.4544999999999999, 0.642000000000000015, 0.357499999999999984, 0.353999999999999981, 9),
(2433, 'F', 0.609999999999999987, 0.484999999999999987, 0.165000000000000008, 1.08699999999999997, 0.425499999999999989, 0.232000000000000012, 0.380000000000000004, 11),
(552, 'I', 0.614999999999999991, 0.489999999999999991, 0.154999999999999999, 0.988500000000000045, 0.41449999999999998, 0.195000000000000007, 0.344999999999999973, 13),
(1425, 'F', 0.729999999999999982, 0.57999999999999996, 0.190000000000000002, 1.73750000000000004, 0.678499999999999992, 0.434499999999999997, 0.520000000000000018, 11),
(2402, 'F', 0.584999999999999964, 0.41499999999999998, 0.154999999999999999, 0.69850000000000001, 0.299999999999999989, 0.145999999999999991, 0.195000000000000007, 12),
(1748, 'F', 0.699999999999999956, 0.535000000000000031, 0.174999999999999989, 1.77299999999999991, 0.680499999999999994, 0.479999999999999982, 0.512000000000000011, 15),
(3983, 'I', 0.57999999999999996, 0.434999999999999998, 0.149999999999999994, 0.891499999999999959, 0.362999999999999989, 0.192500000000000004, 0.251500000000000001, 6),
(335, 'F', 0.739999999999999991, 0.599999999999999978, 0.195000000000000007, 1.97399999999999998, 0.597999999999999976, 0.408499999999999974, 0.709999999999999964, 16),
(1587, 'I', 0.515000000000000013, 0.349999999999999978, 0.104999999999999996, 0.474499999999999977, 0.212999999999999995, 0.122999999999999998, 0.127500000000000002, 10),
(2448, 'I', 0.275000000000000022, 0.204999999999999988, 0.0800000000000000017, 0.096000000000000002, 0.0359999999999999973, 0.0184999999999999991, 0.0299999999999999989, 6),
(1362, 'F', 0.604999999999999982, 0.474999999999999978, 0.174999999999999989, 1.07600000000000007, 0.463000000000000023, 0.219500000000000001, 0.33500000000000002, 9),
(2799, 'M', 0.640000000000000013, 0.484999999999999987, 0.149999999999999994, 1.09800000000000009, 0.519499999999999962, 0.222000000000000003, 0.317500000000000004, 10),
(1413, 'F', 0.67000000000000004, 0.505000000000000004, 0.174999999999999989, 1.01449999999999996, 0.4375, 0.271000000000000019, 0.3745, 10),
(1739, 'F', 0.67000000000000004, 0.540000000000000036, 0.195000000000000007, 1.61899999999999999, 0.739999999999999991, 0.330500000000000016, 0.465000000000000024, 11),
(1152, 'M', 0.584999999999999964, 0.465000000000000024, 0.160000000000000003, 0.955500000000000016, 0.45950000000000002, 0.235999999999999988, 0.265000000000000013, 7),
(2427, 'I', 0.564999999999999947, 0.434999999999999998, 0.154999999999999999, 0.782000000000000028, 0.271500000000000019, 0.16800000000000001, 0.284999999999999976, 14),
(1777, 'M', 0.484999999999999987, 0.369999999999999996, 0.154999999999999999, 0.967999999999999972, 0.418999999999999984, 0.245499999999999996, 0.236499999999999988, 9),
(3294, 'M', 0.574999999999999956, 0.455000000000000016, 0.184999999999999998, 1.15599999999999992, 0.552499999999999991, 0.242999999999999994, 0.294999999999999984, 13),
(1403, 'M', 0.650000000000000022, 0.510000000000000009, 0.190000000000000002, 1.54200000000000004, 0.715500000000000025, 0.373499999999999999, 0.375, 9),
(2256, 'M', 0.510000000000000009, 0.395000000000000018, 0.14499999999999999, 0.61850000000000005, 0.215999999999999998, 0.138500000000000012, 0.239999999999999991, 12),
(3984, 'F', 0.584999999999999964, 0.450000000000000011, 0.125, 0.873999999999999999, 0.354499999999999982, 0.20749999999999999, 0.225000000000000006, 6),
(1116, 'M', 0.525000000000000022, 0.405000000000000027, 0.119999999999999996, 0.755499999999999949, 0.3755, 0.155499999999999999, 0.201000000000000012, 9),
(1366, 'M', 0.609999999999999987, 0.474999999999999978, 0.170000000000000012, 1.02649999999999997, 0.434999999999999998, 0.233500000000000013, 0.303499999999999992, 10),
(3759, 'I', 0.525000000000000022, 0.400000000000000022, 0.140000000000000013, 0.605500000000000038, 0.260500000000000009, 0.107999999999999999, 0.209999999999999992, 9);
| |
-- glm(): inverse Gaussian family with the sqr_inverse link. tolerance=1e-16
-- presumably ensures the run stops at max_iter=2 rather than on convergence.
DROP TABLE IF EXISTS abalone_out, abalone_out_summary;
SELECT glm(
    'abalone',                                      -- source table
    'abalone_out',                                  -- output model table
    'rings',                                        -- dependent variable
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=inverse_gaussian, link=sqr_inverse',    -- family / link
    NULL,                                           -- grouping columns (none)
    'max_iter=2, tolerance=1e-16'                   -- optimizer parameters
);
| |
-- glm(): Gaussian family with the identity link (ordinary least squares),
-- limited to two iterations.
DROP TABLE IF EXISTS abalone_out, abalone_out_summary;
SELECT glm(
    'abalone',                                      -- source table
    'abalone_out',                                  -- output model table
    'rings',                                        -- dependent variable
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=gaussian, link=identity',               -- family / link
    NULL,                                           -- grouping columns (none)
    'max_iter=2'                                    -- optimizer parameters
);
| |
-- glm(): Gamma family with the inverse link, capped at two iterations
-- (tolerance=1e-16 presumably keeps it from stopping earlier on convergence).
DROP TABLE IF EXISTS abalone_out, abalone_out_summary;
SELECT glm(
    'abalone',                                      -- source table
    'abalone_out',                                  -- output model table
    'rings',                                        -- dependent variable
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=gamma, link=inverse',                   -- family / link
    NULL,                                           -- grouping columns (none)
    'max_iter=2, tolerance=1e-16'                   -- optimizer parameters
);
| |
-- glm(): binomial family with the probit link. The response is the boolean
-- expression rings < 10, evaluated per row.
DROP TABLE IF EXISTS abalone_probit_out, abalone_probit_out_summary;
SELECT glm(
    'abalone',                                      -- source table
    'abalone_probit_out',                           -- output model table
    'rings < 10',                                   -- dependent expression
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=binomial, link=probit',                 -- family / link
    NULL,                                           -- grouping columns (none)
    'max_iter=2, tolerance=1e-16'                   -- optimizer parameters
);
| |
-- Warp-break counts by wool type and tension level, plus a one-character
-- label column g ('0' / '1' in the rows below). The id column is filled
-- automatically by its sequence.
DROP TABLE IF EXISTS warpbreaks CASCADE;
CREATE TABLE warpbreaks (
    id       SERIAL,
    breaks   INTEGER,
    wool     CHAR(1),
    tension  CHAR(1),
    g        CHAR(1)
);
| |
-- 54 rows: wool in {A, B} x tension in {L, M, H}, 9 rows per cell.
-- id is omitted so the SERIAL default numbers the rows in insertion order.
INSERT INTO warpbreaks (breaks, wool, tension, g)
VALUES
    -- wool A, tension L
    (26, 'A', 'L', '1'),
    (30, 'A', 'L', '1'),
    (54, 'A', 'L', '1'),
    (25, 'A', 'L', '1'),
    (70, 'A', 'L', '1'),
    (52, 'A', 'L', '1'),
    (51, 'A', 'L', '1'),
    (26, 'A', 'L', '1'),
    (67, 'A', 'L', '1'),
    -- wool A, tension M
    (18, 'A', 'M', '1'),
    (21, 'A', 'M', '1'),
    (29, 'A', 'M', '1'),
    (17, 'A', 'M', '1'),
    (12, 'A', 'M', '1'),
    (18, 'A', 'M', '1'),
    (35, 'A', 'M', '1'),
    (30, 'A', 'M', '1'),
    (36, 'A', 'M', '1'),
    -- wool A, tension H
    (36, 'A', 'H', '0'),
    (21, 'A', 'H', '0'),
    (24, 'A', 'H', '0'),
    (18, 'A', 'H', '0'),
    (10, 'A', 'H', '0'),
    (43, 'A', 'H', '0'),
    (28, 'A', 'H', '0'),
    (15, 'A', 'H', '0'),
    (26, 'A', 'H', '0'),
    -- wool B, tension L
    (27, 'B', 'L', '0'),
    (14, 'B', 'L', '0'),
    (29, 'B', 'L', '0'),
    (19, 'B', 'L', '0'),
    (29, 'B', 'L', '0'),
    (31, 'B', 'L', '0'),
    (41, 'B', 'L', '0'),
    (20, 'B', 'L', '1'),
    (44, 'B', 'L', '1'),
    -- wool B, tension M
    (42, 'B', 'M', '1'),
    (26, 'B', 'M', '1'),
    (19, 'B', 'M', '1'),
    (16, 'B', 'M', '1'),
    (39, 'B', 'M', '1'),
    (28, 'B', 'M', '1'),
    (21, 'B', 'M', '1'),
    (39, 'B', 'M', '1'),
    (29, 'B', 'M', '1'),
    -- wool B, tension H
    (20, 'B', 'H', '1'),
    (21, 'B', 'H', '1'),
    (24, 'B', 'H', '1'),
    (17, 'B', 'H', '1'),
    (13, 'B', 'H', '1'),
    (15, 'B', 'H', '1'),
    (15, 'B', 'H', '1'),
    (16, 'B', 'H', '1'),
    (28, 'B', 'H', '1');
| |
-- Expand the categorical columns wool and tension into indicator columns
-- (the glm() call below reads "wool_B", "tension_M" and "tension_H").
DROP TABLE IF EXISTS warpbreaks_dummy;
SELECT create_indicator_variables(
    'warpbreaks',        -- source table
    'warpbreaks_dummy',  -- output table
    'wool,tension'       -- categorical columns to expand
);
| |
| -- all assertion answers from R: |
| |
--------------------------------------------------------------------------
-- sqrt
-- Reference fit in R:
-- glm(breaks~wool+tension, family=poisson(link=sqrt), data=warpbreaks)
-- Any model left over from a previous run is dropped first, as is done
-- before the other glm() calls in this file, so the test is re-runnable.
DROP TABLE IF EXISTS glm_model_sqrt, glm_model_sqrt_summary;
SELECT glm(
    'warpbreaks_dummy',                                -- source table
    'glm_model_sqrt',                                  -- output model table
    'breaks',                                          -- dependent variable
    'ARRAY[1.0,"wool_B","tension_M", "tension_H"]',    -- intercept + indicators
    'family=poisson, link=sqrt',                       -- family / link
    NULL,                                              -- grouping columns (none)
    'max_iter=2'                                       -- optimizer parameters
);
| |