-- blob: 8d66ee2ea76d2aa8d718e16bae3be9082c23aaf3 [file]
/* ----------------------------------------------------------------------- *//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ----------------------------------------------------------------------- */
/* -----------------------------------------------------------------------------
* Test Multinomial Logistic Regression, together with the ordinal and
* generalized linear model (glm) calls that follow in this file.
* -------------------------------------------------------------------------- */
/*
* The values given by the multinomial logistic regression were cross checked
* with the Matlab command mnrfit, which is documented at
* http://www.mathworks.com/help/toolbox/stats/mnrfit.html
*
* One important detail of the mnrfit command is that, due to a difference in convention,
* the coefficients it returns are the negatives of our coefficients. Our
* convention is chosen to match the convention of the binary
* logistic regression implementation in MADlib.
*
* For completeness, the matlab code needed to check the answers to the 'test3' example
* is included below. The code assumes that the data is contained in a csv file
* and that the columns haven't changed order. The coefficients will be in the
* 'B' variable.
*
* BEGIN CODE
*
data = csvread(csvFilename);
N = size(data, 1); % Number of records
J = size(data, 2)-1; % Number of covariates
% Integer encoded categories {0,1...K-1}
int_y = 1+data(:,end); % Categories
x = data(:,1:end-1); % Independent variables
% Pivot around the last data point
[B,dev,stats] = mnrfit(x,int_y)
*
* END CODE
*/
-- Fixture table read by the multinom() and ordinal() calls in this file:
-- two integer features, an integer category label, and a one-character
-- column g.
DROP TABLE IF EXISTS multinom_test;

CREATE TABLE multinom_test (
    feat1 INTEGER,
    feat2 INTEGER,
    cat   INTEGER,
    g     CHAR
);
-- 83 observations. feat1 is one of 1..3, cat is one of 0, 1, 2, and every
-- row carries g = 'A'. Row order is kept exactly as in the original fixture.
INSERT INTO multinom_test (feat1, feat2, cat, g)
VALUES
    (1, 35, 1, 'A'),
    (2, 33, 0, 'A'),
    (3, 39, 1, 'A'),
    (1, 37, 1, 'A'),
    (2, 31, 1, 'A'),
    (3, 36, 0, 'A'),
    (2, 36, 1, 'A'),
    (2, 31, 1, 'A'),
    (2, 41, 1, 'A'),
    (2, 37, 1, 'A'),
    (1, 44, 1, 'A'),
    (3, 33, 2, 'A'),
    (1, 31, 1, 'A'),
    (2, 44, 1, 'A'),
    (1, 35, 1, 'A'),
    (1, 44, 0, 'A'),
    (1, 46, 0, 'A'),
    (2, 46, 1, 'A'),
    (2, 46, 2, 'A'),
    (3, 49, 1, 'A'),
    (2, 39, 0, 'A'),
    (2, 44, 1, 'A'),
    (1, 47, 1, 'A'),
    (1, 44, 1, 'A'),
    (1, 37, 2, 'A'),
    (3, 38, 2, 'A'),
    (1, 49, 0, 'A'),
    (2, 44, 0, 'A'),
    (1, 41, 2, 'A'),
    (1, 50, 2, 'A'),
    (2, 44, 0, 'A'),
    (1, 39, 1, 'A'),
    (1, 40, 2, 'A'),
    (1, 46, 2, 'A'),
    (2, 41, 1, 'A'),
    (2, 39, 1, 'A'),
    (2, 33, 1, 'A'),
    (3, 59, 2, 'A'),
    (1, 41, 0, 'A'),
    (2, 47, 2, 'A'),
    (2, 31, 0, 'A'),
    (3, 42, 2, 'A'),
    (1, 55, 2, 'A'),
    (3, 40, 1, 'A'),
    (1, 44, 2, 'A'),
    (1, 54, 1, 'A'),
    (2, 46, 1, 'A'),
    (1, 54, 0, 'A'),
    (2, 42, 1, 'A'),
    (2, 49, 2, 'A'),
    (2, 41, 2, 'A'),
    (2, 41, 1, 'A'),
    (1, 44, 0, 'A'),
    (1, 57, 2, 'A'),
    (2, 52, 2, 'A'),
    (1, 49, 0, 'A'),
    (3, 41, 2, 'A'),
    (3, 57, 0, 'A'),
    (1, 62, 1, 'A'),
    (3, 33, 0, 'A'),
    (2, 54, 1, 'A'),
    (2, 40, 2, 'A'),
    (3, 52, 2, 'A'),
    (2, 57, 1, 'A'),
    (2, 49, 1, 'A'),
    (2, 46, 1, 'A'),
    (1, 57, 0, 'A'),
    (2, 49, 2, 'A'),
    (2, 52, 2, 'A'),
    (2, 53, 0, 'A'),
    (3, 54, 2, 'A'),
    (2, 57, 2, 'A'),
    (3, 41, 2, 'A'),
    (1, 52, 0, 'A'),
    (2, 57, 1, 'A'),
    (1, 54, 0, 'A'),
    (2, 52, 1, 'A'),
    (2, 52, 0, 'A'),
    (2, 44, 0, 'A'),
    (2, 46, 2, 'A'),
    (1, 49, 1, 'A'),
    (2, 54, 2, 'A'),
    (3, 52, 2, 'A');
-- Multinomial regression of cat on feat1 and feat2 from multinom_test,
-- using the logit link.
--
-- Clear any output left by a previous run so the script can be re-executed
-- in the same schema. This mirrors the DROP ... <out>, <out>_summary pattern
-- used ahead of the glm() calls in this file; IF EXISTS makes it a no-op
-- when either table is absent.
DROP TABLE IF EXISTS mglm_test, mglm_test_summary;
SELECT multinom(
    'multinom_test',            -- table read by the fit
    'mglm_test',                -- table the fit writes to
    'cat',                      -- response column
    'ARRAY[feat1, feat2, 1]',   -- predictors; the literal 1 is presumably the intercept term
    '2',                        -- presumably the reference category -- confirm against MADlib docs
    'logit',                    -- link function
    NULL,                       -- NOTE(review): positional optional argument left at NULL
    NULL,                       -- NOTE(review): positional optional argument left at NULL
    TRUE                        -- NOTE(review): looks like a verbosity flag -- confirm
);
-- Ordinal regression of cat on feat1 and feat2 from multinom_test, with the
-- category ordering 0 < 1 < 2 and the logit link.
--
-- Clear any output left by a previous run so the script can be re-executed
-- in the same schema. This mirrors the DROP ... <out>, <out>_summary pattern
-- used ahead of the glm() calls in this file; IF EXISTS makes it a no-op
-- when either table is absent.
DROP TABLE IF EXISTS ordinal_out, ordinal_out_summary;
SELECT ordinal(
    'multinom_test',            -- table read by the fit
    'ordinal_out',              -- table the fit writes to
    'cat',                      -- response column
    'ARRAY[feat1, feat2]',      -- predictors (no constant term is listed here)
    '0<1<2',                    -- ordering of the response categories
    'logit'                     -- link function
);
-- Fixture table for the glm() runs on abalone: an id, a text sex code,
-- seven double-precision measurements, and the integer rings column that
-- those runs use as (or derive) the response.
DROP TABLE IF EXISTS abalone CASCADE;

CREATE TABLE abalone (
    id       INTEGER,
    sex      TEXT,
    length   DOUBLE PRECISION,
    diameter DOUBLE PRECISION,
    height   DOUBLE PRECISION,
    whole    DOUBLE PRECISION,
    shucked  DOUBLE PRECISION,
    viscera  DOUBLE PRECISION,
    shell    DOUBLE PRECISION,
    rings    INTEGER
);
-- 60 abalone observations. The column list is spelled out (in the order the
-- abalone table declares its columns) so the load does not silently depend
-- on the table's physical column order.
INSERT INTO abalone
    (id, sex, length, diameter, height, whole, shucked, viscera, shell, rings)
VALUES
    (3151, 'F', 0.655000000000000027, 0.505000000000000004, 0.165000000000000008, 1.36699999999999999, 0.583500000000000019, 0.351499999999999979, 0.396000000000000019, 10),
    (2026, 'F', 0.550000000000000044, 0.469999999999999973, 0.149999999999999994, 0.920499999999999985, 0.381000000000000005, 0.243499999999999994, 0.267500000000000016, 10),
    (3751, 'I', 0.434999999999999998, 0.375, 0.110000000000000001, 0.41549999999999998, 0.170000000000000012, 0.0759999999999999981, 0.14499999999999999, 8),
    (720, 'I', 0.149999999999999994, 0.100000000000000006, 0.0250000000000000014, 0.0149999999999999994, 0.00449999999999999966, 0.00400000000000000008, 0.0050000000000000001, 2),
    (1635, 'F', 0.574999999999999956, 0.469999999999999973, 0.154999999999999999, 1.1160000000000001, 0.509000000000000008, 0.237999999999999989, 0.340000000000000024, 10),
    (2648, 'I', 0.5, 0.390000000000000013, 0.125, 0.582999999999999963, 0.293999999999999984, 0.132000000000000006, 0.160500000000000004, 8),
    (1796, 'F', 0.57999999999999996, 0.429999999999999993, 0.170000000000000012, 1.47999999999999998, 0.65349999999999997, 0.32400000000000001, 0.41549999999999998, 10),
    (209, 'F', 0.525000000000000022, 0.41499999999999998, 0.170000000000000012, 0.832500000000000018, 0.275500000000000023, 0.168500000000000011, 0.309999999999999998, 13),
    (1451, 'I', 0.455000000000000016, 0.33500000000000002, 0.135000000000000009, 0.501000000000000001, 0.274000000000000021, 0.0995000000000000051, 0.106499999999999997, 7),
    (1108, 'I', 0.510000000000000009, 0.380000000000000004, 0.115000000000000005, 0.515499999999999958, 0.214999999999999997, 0.113500000000000004, 0.166000000000000009, 8),
    (3675, 'F', 0.594999999999999973, 0.450000000000000011, 0.165000000000000008, 1.08099999999999996, 0.489999999999999991, 0.252500000000000002, 0.279000000000000026, 12),
    (2108, 'F', 0.675000000000000044, 0.550000000000000044, 0.179999999999999993, 1.68849999999999989, 0.562000000000000055, 0.370499999999999996, 0.599999999999999978, 15),
    (3312, 'F', 0.479999999999999982, 0.380000000000000004, 0.135000000000000009, 0.507000000000000006, 0.191500000000000004, 0.13650000000000001, 0.154999999999999999, 12),
    (882, 'M', 0.655000000000000027, 0.520000000000000018, 0.165000000000000008, 1.40949999999999998, 0.585999999999999965, 0.290999999999999981, 0.405000000000000027, 9),
    (3402, 'M', 0.479999999999999982, 0.395000000000000018, 0.149999999999999994, 0.681499999999999995, 0.214499999999999996, 0.140500000000000014, 0.2495, 18),
    (829, 'I', 0.409999999999999976, 0.325000000000000011, 0.100000000000000006, 0.394000000000000017, 0.20799999999999999, 0.0655000000000000027, 0.105999999999999997, 6),
    (1305, 'M', 0.535000000000000031, 0.434999999999999998, 0.149999999999999994, 0.716999999999999971, 0.347499999999999976, 0.14449999999999999, 0.194000000000000006, 9),
    (3613, 'M', 0.599999999999999978, 0.46000000000000002, 0.179999999999999993, 1.1399999999999999, 0.422999999999999987, 0.257500000000000007, 0.364999999999999991, 10),
    (1068, 'I', 0.340000000000000024, 0.265000000000000013, 0.0800000000000000017, 0.201500000000000012, 0.0899999999999999967, 0.0475000000000000006, 0.0550000000000000003, 5),
    (2446, 'M', 0.5, 0.380000000000000004, 0.135000000000000009, 0.583500000000000019, 0.22950000000000001, 0.126500000000000001, 0.179999999999999993, 12),
    (1393, 'M', 0.635000000000000009, 0.474999999999999978, 0.170000000000000012, 1.19350000000000001, 0.520499999999999963, 0.269500000000000017, 0.366499999999999992, 10),
    (359, 'M', 0.744999999999999996, 0.584999999999999964, 0.214999999999999997, 2.49900000000000011, 0.92649999999999999, 0.471999999999999975, 0.699999999999999956, 17),
    (549, 'F', 0.564999999999999947, 0.450000000000000011, 0.160000000000000003, 0.79500000000000004, 0.360499999999999987, 0.155499999999999999, 0.23000000000000001, 12),
    (1154, 'F', 0.599999999999999978, 0.474999999999999978, 0.160000000000000003, 1.02649999999999997, 0.484999999999999987, 0.2495, 0.256500000000000006, 9),
    (1790, 'F', 0.54500000000000004, 0.385000000000000009, 0.149999999999999994, 1.11850000000000005, 0.542499999999999982, 0.244499999999999995, 0.284499999999999975, 9),
    (3703, 'F', 0.665000000000000036, 0.540000000000000036, 0.195000000000000007, 1.76400000000000001, 0.850500000000000034, 0.361499999999999988, 0.469999999999999973, 11),
    (1962, 'F', 0.655000000000000027, 0.515000000000000013, 0.179999999999999993, 1.41199999999999992, 0.619500000000000051, 0.248499999999999999, 0.496999999999999997, 11),
    (1665, 'I', 0.604999999999999982, 0.469999999999999973, 0.14499999999999999, 0.802499999999999991, 0.379000000000000004, 0.226500000000000007, 0.220000000000000001, 9),
    (635, 'M', 0.359999999999999987, 0.294999999999999984, 0.100000000000000006, 0.210499999999999993, 0.0660000000000000031, 0.0524999999999999981, 0.0749999999999999972, 9),
    (3901, 'M', 0.445000000000000007, 0.344999999999999973, 0.140000000000000013, 0.475999999999999979, 0.205499999999999988, 0.101500000000000007, 0.108499999999999999, 15),
    (2734, 'I', 0.41499999999999998, 0.33500000000000002, 0.100000000000000006, 0.357999999999999985, 0.169000000000000011, 0.067000000000000004, 0.104999999999999996, 7),
    (3856, 'M', 0.409999999999999976, 0.33500000000000002, 0.115000000000000005, 0.440500000000000003, 0.190000000000000002, 0.0850000000000000061, 0.135000000000000009, 8),
    (827, 'I', 0.395000000000000018, 0.28999999999999998, 0.0950000000000000011, 0.303999999999999992, 0.127000000000000002, 0.0840000000000000052, 0.076999999999999999, 6),
    (3381, 'I', 0.190000000000000002, 0.130000000000000004, 0.0449999999999999983, 0.0264999999999999993, 0.00899999999999999932, 0.0050000000000000001, 0.00899999999999999932, 5),
    (3972, 'I', 0.400000000000000022, 0.294999999999999984, 0.0950000000000000011, 0.252000000000000002, 0.110500000000000001, 0.0575000000000000025, 0.0660000000000000031, 6),
    (1155, 'M', 0.599999999999999978, 0.455000000000000016, 0.170000000000000012, 1.1915, 0.695999999999999952, 0.239499999999999991, 0.239999999999999991, 8),
    (3467, 'M', 0.640000000000000013, 0.5, 0.170000000000000012, 1.4544999999999999, 0.642000000000000015, 0.357499999999999984, 0.353999999999999981, 9),
    (2433, 'F', 0.609999999999999987, 0.484999999999999987, 0.165000000000000008, 1.08699999999999997, 0.425499999999999989, 0.232000000000000012, 0.380000000000000004, 11),
    (552, 'I', 0.614999999999999991, 0.489999999999999991, 0.154999999999999999, 0.988500000000000045, 0.41449999999999998, 0.195000000000000007, 0.344999999999999973, 13),
    (1425, 'F', 0.729999999999999982, 0.57999999999999996, 0.190000000000000002, 1.73750000000000004, 0.678499999999999992, 0.434499999999999997, 0.520000000000000018, 11),
    (2402, 'F', 0.584999999999999964, 0.41499999999999998, 0.154999999999999999, 0.69850000000000001, 0.299999999999999989, 0.145999999999999991, 0.195000000000000007, 12),
    (1748, 'F', 0.699999999999999956, 0.535000000000000031, 0.174999999999999989, 1.77299999999999991, 0.680499999999999994, 0.479999999999999982, 0.512000000000000011, 15),
    (3983, 'I', 0.57999999999999996, 0.434999999999999998, 0.149999999999999994, 0.891499999999999959, 0.362999999999999989, 0.192500000000000004, 0.251500000000000001, 6),
    (335, 'F', 0.739999999999999991, 0.599999999999999978, 0.195000000000000007, 1.97399999999999998, 0.597999999999999976, 0.408499999999999974, 0.709999999999999964, 16),
    (1587, 'I', 0.515000000000000013, 0.349999999999999978, 0.104999999999999996, 0.474499999999999977, 0.212999999999999995, 0.122999999999999998, 0.127500000000000002, 10),
    (2448, 'I', 0.275000000000000022, 0.204999999999999988, 0.0800000000000000017, 0.096000000000000002, 0.0359999999999999973, 0.0184999999999999991, 0.0299999999999999989, 6),
    (1362, 'F', 0.604999999999999982, 0.474999999999999978, 0.174999999999999989, 1.07600000000000007, 0.463000000000000023, 0.219500000000000001, 0.33500000000000002, 9),
    (2799, 'M', 0.640000000000000013, 0.484999999999999987, 0.149999999999999994, 1.09800000000000009, 0.519499999999999962, 0.222000000000000003, 0.317500000000000004, 10),
    (1413, 'F', 0.67000000000000004, 0.505000000000000004, 0.174999999999999989, 1.01449999999999996, 0.4375, 0.271000000000000019, 0.3745, 10),
    (1739, 'F', 0.67000000000000004, 0.540000000000000036, 0.195000000000000007, 1.61899999999999999, 0.739999999999999991, 0.330500000000000016, 0.465000000000000024, 11),
    (1152, 'M', 0.584999999999999964, 0.465000000000000024, 0.160000000000000003, 0.955500000000000016, 0.45950000000000002, 0.235999999999999988, 0.265000000000000013, 7),
    (2427, 'I', 0.564999999999999947, 0.434999999999999998, 0.154999999999999999, 0.782000000000000028, 0.271500000000000019, 0.16800000000000001, 0.284999999999999976, 14),
    (1777, 'M', 0.484999999999999987, 0.369999999999999996, 0.154999999999999999, 0.967999999999999972, 0.418999999999999984, 0.245499999999999996, 0.236499999999999988, 9),
    (3294, 'M', 0.574999999999999956, 0.455000000000000016, 0.184999999999999998, 1.15599999999999992, 0.552499999999999991, 0.242999999999999994, 0.294999999999999984, 13),
    (1403, 'M', 0.650000000000000022, 0.510000000000000009, 0.190000000000000002, 1.54200000000000004, 0.715500000000000025, 0.373499999999999999, 0.375, 9),
    (2256, 'M', 0.510000000000000009, 0.395000000000000018, 0.14499999999999999, 0.61850000000000005, 0.215999999999999998, 0.138500000000000012, 0.239999999999999991, 12),
    (3984, 'F', 0.584999999999999964, 0.450000000000000011, 0.125, 0.873999999999999999, 0.354499999999999982, 0.20749999999999999, 0.225000000000000006, 6),
    (1116, 'M', 0.525000000000000022, 0.405000000000000027, 0.119999999999999996, 0.755499999999999949, 0.3755, 0.155499999999999999, 0.201000000000000012, 9),
    (1366, 'M', 0.609999999999999987, 0.474999999999999978, 0.170000000000000012, 1.02649999999999997, 0.434999999999999998, 0.233500000000000013, 0.303499999999999992, 10),
    (3759, 'I', 0.525000000000000022, 0.400000000000000022, 0.140000000000000013, 0.605500000000000038, 0.260500000000000009, 0.107999999999999999, 0.209999999999999992, 9);
-- GLM run: inverse_gaussian family with the sqr_inverse link, rings as the
-- response. The output tables are dropped first because the other abalone
-- runs write to the same abalone_out name.
DROP TABLE IF EXISTS abalone_out, abalone_out_summary;
SELECT glm(
    'abalone',                                      -- table read by the fit
    'abalone_out',                                  -- table the fit writes to
    'rings',                                        -- response
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=inverse_gaussian, link=sqr_inverse',
    NULL,
    'max_iter=2, tolerance=1e-16'                   -- presumably caps the fit at two iterations -- confirm
);
-- GLM run: gaussian family with the identity link, rings as the response.
-- The output tables are dropped first because the other abalone runs write
-- to the same abalone_out name.
DROP TABLE IF EXISTS abalone_out, abalone_out_summary;
SELECT glm(
    'abalone',                                      -- table read by the fit
    'abalone_out',                                  -- table the fit writes to
    'rings',                                        -- response
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=gaussian, link=identity',
    NULL,
    'max_iter=2'                                    -- presumably caps the fit at two iterations -- confirm
);
-- GLM run: gamma family with the inverse link, rings as the response.
-- The output tables are dropped first because the other abalone runs write
-- to the same abalone_out name.
DROP TABLE IF EXISTS abalone_out, abalone_out_summary;
SELECT glm(
    'abalone',                                      -- table read by the fit
    'abalone_out',                                  -- table the fit writes to
    'rings',                                        -- response
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=gamma, link=inverse',
    NULL,
    'max_iter=2, tolerance=1e-16'                   -- presumably caps the fit at two iterations -- confirm
);
-- GLM run: binomial family with the probit link. The response is the
-- boolean expression "rings < 10" rather than a stored column, and the
-- result goes to its own abalone_probit_out table.
DROP TABLE IF EXISTS abalone_probit_out, abalone_probit_out_summary;
SELECT glm(
    'abalone',                                      -- table read by the fit
    'abalone_probit_out',                           -- table the fit writes to
    'rings < 10',                                   -- response expression
    'ARRAY[1, length, diameter, height, whole, shucked, viscera, shell]',
    'family=binomial, link=probit',
    NULL,
    'max_iter=2, tolerance=1e-16'                   -- presumably caps the fit at two iterations -- confirm
);
-- Fixture table for the poisson GLM run: an auto-generated id, the integer
-- breaks count, and three one-character columns (wool, tension, g).
DROP TABLE IF EXISTS warpbreaks CASCADE;

CREATE TABLE warpbreaks (
    id      SERIAL,
    breaks  INTEGER,
    wool    CHAR(1),
    tension CHAR(1),
    g       CHAR(1)
);
-- 54 observations: wool is 'A' or 'B', tension is 'L', 'M' or 'H', and g is
-- '0' or '1'. id is not listed, so it is filled in by the serial default.
-- Row order is kept exactly as in the original fixture.
INSERT INTO warpbreaks (breaks, wool, tension, g)
VALUES
    (26, 'A', 'L', '1'),
    (30, 'A', 'L', '1'),
    (54, 'A', 'L', '1'),
    (25, 'A', 'L', '1'),
    (70, 'A', 'L', '1'),
    (52, 'A', 'L', '1'),
    (51, 'A', 'L', '1'),
    (26, 'A', 'L', '1'),
    (67, 'A', 'L', '1'),
    (18, 'A', 'M', '1'),
    (21, 'A', 'M', '1'),
    (29, 'A', 'M', '1'),
    (17, 'A', 'M', '1'),
    (12, 'A', 'M', '1'),
    (18, 'A', 'M', '1'),
    (35, 'A', 'M', '1'),
    (30, 'A', 'M', '1'),
    (36, 'A', 'M', '1'),
    (36, 'A', 'H', '0'),
    (21, 'A', 'H', '0'),
    (24, 'A', 'H', '0'),
    (18, 'A', 'H', '0'),
    (10, 'A', 'H', '0'),
    (43, 'A', 'H', '0'),
    (28, 'A', 'H', '0'),
    (15, 'A', 'H', '0'),
    (26, 'A', 'H', '0'),
    (27, 'B', 'L', '0'),
    (14, 'B', 'L', '0'),
    (29, 'B', 'L', '0'),
    (19, 'B', 'L', '0'),
    (29, 'B', 'L', '0'),
    (31, 'B', 'L', '0'),
    (41, 'B', 'L', '0'),
    (20, 'B', 'L', '1'),
    (44, 'B', 'L', '1'),
    (42, 'B', 'M', '1'),
    (26, 'B', 'M', '1'),
    (19, 'B', 'M', '1'),
    (16, 'B', 'M', '1'),
    (39, 'B', 'M', '1'),
    (28, 'B', 'M', '1'),
    (21, 'B', 'M', '1'),
    (39, 'B', 'M', '1'),
    (29, 'B', 'M', '1'),
    (20, 'B', 'H', '1'),
    (21, 'B', 'H', '1'),
    (24, 'B', 'H', '1'),
    (17, 'B', 'H', '1'),
    (13, 'B', 'H', '1'),
    (15, 'B', 'H', '1'),
    (15, 'B', 'H', '1'),
    (16, 'B', 'H', '1'),
    (28, 'B', 'H', '1');
-- Rebuild warpbreaks_dummy from warpbreaks, expanding the wool and tension
-- columns. The GLM call further down reads the wool_B, tension_M and
-- tension_H columns from the resulting table.
DROP TABLE IF EXISTS warpbreaks_dummy;
SELECT create_indicator_variables(
    'warpbreaks',
    'warpbreaks_dummy',
    'wool,tension'
);
-- all assertion answers from R:
--------------------------------------------------------------------------
-- sqrt
-- glm(breaks~wool+tension, family=poisson(link=sqrt), data=warpbreaks)
--
-- Poisson GLM with the sqrt link on the indicator-expanded warpbreaks data.
-- Clear any output left by a previous run so the script can be re-executed
-- in the same schema; the other glm() calls in this file drop
-- <out>, <out>_summary the same way. IF EXISTS makes it a no-op when the
-- tables are absent.
DROP TABLE IF EXISTS glm_model_sqrt, glm_model_sqrt_summary;
SELECT glm(
    'warpbreaks_dummy',                                 -- table read by the fit
    'glm_model_sqrt',                                   -- table the fit writes to
    'breaks',                                           -- response
    'ARRAY[1.0,"wool_B","tension_M", "tension_H"]',     -- constant term plus three indicator columns
    'family=poisson, link=sqrt',
    NULL,
    'max_iter=2'                                        -- presumably caps the fit at two iterations -- confirm
);