-- blob: 936aa08ac26caa08278e9b3864b013d5e3dc1522 [file] [log] [blame]
/* -----------------------------------------------------------------------------
* Test Multinomial Logistic Regression.
* -------------------------------------------------------------------------- */
/*
* The values given by the multinomial logistic regression were cross checked
* with the Matlab command mnrfit, which is documented at
* http://www.mathworks.com/help/toolbox/stats/mnrfit.html
*
* One important detail of the mnrfit command is that, due to a difference in
* convention, its coefficient estimates are the negatives of ours. Our
* convention was chosen to match that of the binary logistic regression
* implementation in MADlib.
*
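* For completeness, the MATLAB code needed to check the answers for the
* multinom_test example (the first, ungrouped multinom call) is included below.
* The code assumes that the data has been exported to a CSV file whose columns
* are, in order, feat1, feat2 and cat; the grouping column g must be left out,
* because the last column is read as the category. The coefficients will be in
* the 'B' variable. Note that mnrfit places the intercept in the first row of
* B, whereas that test lists it last (ARRAY[feat1, feat2, 1]).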
*
* BEGIN CODE
*
data = csvread(csvFilename);
N = size(data, 1); % Number of records
J = size(data, 2)-1; % Number of covariates
% Integer encoded categories {0,1...K-1}
int_y = 1+data(:,end); % Categories shifted to {1,...,K}, as mnrfit requires
x = data(:,1:end-1); % Independent variables
% mnrfit uses the last category (K) as the reference (pivot) category
[B,dev,stats] = mnrfit(x,int_y)
*
* END CODE
*/
-- Training fixture: two integer covariates, an integer category label and a
-- single-character grouping key. Recreated on every run so the script is
-- self-contained.
DROP TABLE IF EXISTS multinom_test;
CREATE TABLE multinom_test (
    feat1 INTEGER,   -- first covariate
    feat2 INTEGER,   -- second covariate
    cat   INTEGER,   -- category label (dependent variable)
    g     CHAR(1)    -- grouping key
);
-- 200 fixture rows: the first 152 belong to group A, the last 48 to group B.
-- The category label cat takes the values 0, 1 and 2 in both groups, so each
-- group can be fitted on its own by the grouped run.
INSERT INTO multinom_test(feat1, feat2, cat, g) VALUES
(1,35,1,'A'),
(2,33,0,'A'),
(3,39,1,'A'),
(1,37,1,'A'),
(2,31,1,'A'),
(3,36,0,'A'),
(2,36,1,'A'),
(2,31,1,'A'),
(2,41,1,'A'),
(2,37,1,'A'),
(1,44,1,'A'),
(3,33,2,'A'),
(1,31,1,'A'),
(2,44,1,'A'),
(1,35,1,'A'),
(1,44,0,'A'),
(1,46,0,'A'),
(2,46,1,'A'),
(2,46,2,'A'),
(3,49,1,'A'),
(2,39,0,'A'),
(2,44,1,'A'),
(1,47,1,'A'),
(1,44,1,'A'),
(1,37,2,'A'),
(3,38,2,'A'),
(1,49,0,'A'),
(2,44,0,'A'),
(1,41,2,'A'),
(1,50,2,'A'),
(2,44,0,'A'),
(1,39,1,'A'),
(1,40,2,'A'),
(1,46,2,'A'),
(2,41,1,'A'),
(2,39,1,'A'),
(2,33,1,'A'),
(3,59,2,'A'),
(1,41,0,'A'),
(2,47,2,'A'),
(2,31,0,'A'),
(3,42,2,'A'),
(1,55,2,'A'),
(3,40,1,'A'),
(1,44,2,'A'),
(1,54,1,'A'),
(2,46,1,'A'),
(1,54,0,'A'),
(2,42,1,'A'),
(2,49,2,'A'),
(2,41,2,'A'),
(2,41,1,'A'),
(1,44,0,'A'),
(1,57,2,'A'),
(2,52,2,'A'),
(1,49,0,'A'),
(3,41,2,'A'),
(3,57,0,'A'),
(1,62,1,'A'),
(3,33,0,'A'),
(2,54,1,'A'),
(2,40,2,'A'),
(3,52,2,'A'),
(2,57,1,'A'),
(2,49,1,'A'),
(2,46,1,'A'),
(1,57,0,'A'),
(2,49,2,'A'),
(2,52,2,'A'),
(2,53,0,'A'),
(3,54,2,'A'),
(2,57,2,'A'),
(3,41,2,'A'),
(1,52,0,'A'),
(2,57,1,'A'),
(1,54,0,'A'),
(2,52,1,'A'),
(2,52,0,'A'),
(2,44,0,'A'),
(2,46,2,'A'),
(1,49,1,'A'),
(2,54,2,'A'),
(3,52,2,'A'),
(1,44,0,'A'),
(3,49,1,'A'),
(1,46,2,'A'),
(2,54,0,'A'),
(2,39,0,'A'),
(2,59,0,'A'),
(2,45,1,'A'),
(3,52,1,'A'),
(3,54,0,'A'),
(3,44,1,'A'),
(2,50,2,'A'),
(2,62,1,'A'),
(2,59,0,'A'),
(2,52,2,'A'),
(2,52,1,'A'),
(2,46,1,'A'),
(2,41,0,'A'),
(2,52,2,'A'),
(2,52,1,'A'),
(2,55,1,'A'),
(2,41,1,'A'),
(2,49,0,'A'),
(1,59,2,'A'),
(1,54,0,'A'),
(2,54,0,'A'),
(2,59,2,'A'),
(2,55,2,'A'),
(1,62,2,'A'),
(2,54,2,'A'),
(2,54,2,'A'),
(2,54,2,'A'),
(2,59,2,'A'),
(2,57,1,'A'),
(3,61,2,'A'),
(3,52,2,'A'),
(2,59,2,'A'),
(2,62,2,'A'),
(1,60,1,'A'),
(2,59,2,'A'),
(2,65,2,'A'),
(3,61,2,'A'),
(2,59,2,'A'),
(3,59,2,'A'),
(2,59,2,'A'),
(2,59,2,'A'),
(2,65,2,'A'),
(3,57,2,'A'),
(2,59,2,'A'),
(3,49,2,'A'),
(1,49,0,'A'),
(3,59,2,'A'),
(2,62,2,'A'),
(3,59,0,'A'),
(2,54,2,'A'),
(3,63,2,'A'),
(1,43,2,'A'),
(3,54,2,'A'),
(3,52,2,'A'),
(1,57,2,'A'),
(2,57,0,'A'),
(2,57,0,'A'),
(2,61,2,'A'),
(2,62,0,'A'),
(2,62,0,'A'),
(1,65,0,'A'),
(2,57,2,'A'),
(3,59,2,'A'),
(2,59,2,'A'),
(3,62,2,'A'),
-- group B rows start here
(2,65,2,'B'),
(2,62,1,'B'),
(1,62,0,'B'),
(2,62,2,'B'),
(3,54,2,'B'),
(3,62,2,'B'),
(1,65,2,'B'),
(3,62,2,'B'),
(3,67,0,'B'),
(3,65,0,'B'),
(1,60,2,'B'),
(3,59,2,'B'),
(2,59,2,'B'),
(2,59,1,'B'),
(3,65,0,'B'),
(3,62,2,'B'),
(3,65,2,'B'),
(3,59,0,'B'),
(1,59,0,'B'),
(3,61,2,'B'),
(1,65,2,'B'),
(3,67,1,'B'),
(3,65,2,'B'),
(1,65,2,'B'),
(2,67,2,'B'),
(1,65,2,'B'),
(1,62,2,'B'),
(3,52,2,'B'),
(3,63,2,'B'),
(2,59,2,'B'),
(3,65,2,'B'),
(2,59,0,'B'),
(3,67,2,'B'),
(3,67,2,'B'),
(3,60,2,'B'),
(3,67,2,'B'),
(3,62,2,'B'),
(2,54,2,'B'),
(3,65,2,'B'),
(3,62,2,'B'),
(2,59,2,'B'),
(3,60,2,'B'),
(3,63,2,'B'),
(3,65,2,'B'),
(2,63,1,'B'),
(2,67,2,'B'),
(2,65,2,'B'),
(2,62,2,'B');
-- Ungrouped training run whose results are compared against reference values
-- cross-checked with MATLAB mnrfit.
-- Drop any model left over from a previous run first; otherwise multinom
-- fails because its output table already exists and the script cannot be
-- re-run in the same schema.
DROP TABLE IF EXISTS mglm_test, mglm_test_summary;
SELECT multinom(
    'multinom_test',            -- source table
    'mglm_test',                -- model output table
    'cat',                      -- dependent variable
    'ARRAY[feat1, feat2, 1]',   -- independent variables (intercept last)
    '2',                        -- reference category
    'logit',                    -- link function
    NULL,                       -- grouping columns (none)
    NULL,                       -- optimizer parameters (defaults)
    TRUE);                      -- verbose

\x on
SELECT * FROM mglm_test;
SELECT * FROM mglm_test_summary;
\x off

-- The value check below is evaluated once per matching row, so it would pass
-- silently if no row for category 1 existed. Require exactly one such row.
SELECT assert(
    count(*) = 1,
    'Multinomial regression with IRLS optimizer (test): expected exactly one model row for category 1'
)
FROM mglm_test
WHERE category = '1';

-- Compare the category-1 coefficients and statistics with the reference
-- values. The p-value tolerance is looser (1e-2), presumably because some
-- reference p-values are quoted with only four significant digits.
SELECT assert(
    coef IS NOT NULL AND
    relative_error(coef, ARRAY[-.45117861,-.11221493,5.9900061]) < 1e-4 AND
    relative_error(log_likelihood, -182.22) < 1e-4 AND
    relative_error(std_err, ARRAY[.27290708,.02169723,1.2093327]) < 1e-4 AND
    relative_error(z_stats, ARRAY[-1.6532316,-5.1718545,4.9531499]) < 1e-4 AND
    relative_error(p_values, ARRAY[.09828373,2.318e-07,7.302e-07]) < 1e-2,
    'Multinomial regression with IRLS optimizer (test) on category 1: Wrong results'
)
FROM mglm_test
WHERE category = '1';
-- Prediction runs against the model trained above.
-- First give every training row a serial ID column, i, for the prediction runs.
ALTER TABLE multinom_test ADD COLUMN i SERIAL;

-- Run 1: predicted category ('response'), keyed by the serial column i.
DROP TABLE IF EXISTS mglm_predict;
SELECT multinom_predict(
    'mglm_test',        -- model table
    'multinom_test',    -- table to predict on
    'mglm_predict',     -- output table
    'response',         -- prediction type
    TRUE,               -- verbose
    'i'                 -- id column
);
\x off
SELECT * FROM mglm_predict;

-- Run 2: predicted category ('response'), with an expression as the id.
-- NOTE(review): this run and run 3 pass the covariate expression rather than
-- the serial column i as the id, presumably to exercise an expression-valued
-- id; confirm this is intended.
DROP TABLE IF EXISTS mglm_predict;
SELECT multinom_predict(
    'mglm_test',                -- model table
    'multinom_test',            -- table to predict on
    'mglm_predict',             -- output table
    'response',                 -- prediction type
    TRUE,                       -- verbose
    'ARRAY[feat1, feat2, 1]'    -- id column / expression
);
\x off
SELECT * FROM mglm_predict;

-- Run 3: class probabilities ('prob'), same id expression as run 2.
DROP TABLE IF EXISTS mglm_predict;
SELECT multinom_predict(
    'mglm_test',                -- model table
    'multinom_test',            -- table to predict on
    'mglm_predict',             -- output table
    'prob',                     -- prediction type
    TRUE,                       -- verbose
    'ARRAY[feat1, feat2, 1]'    -- id column / expression
);
\x off
SELECT * FROM mglm_predict;
-- Grouped training run: one model per value of g. Compared with the first
-- run, the intercept comes first and category 0 is the reference category.
DROP TABLE IF EXISTS mglm_test, mglm_test_summary;
SELECT multinom(
    'multinom_test',            -- source table
    'mglm_test',                -- model output table
    'cat',                      -- dependent variable
    'ARRAY[1, feat1, feat2]',   -- independent variables (intercept first)
    '0',                        -- reference category
    'logit',                    -- link function
    'g',                        -- grouping column
    NULL,                       -- optimizer parameters (defaults)
    TRUE);                      -- verbose

-- The fixture contains two groups, A and B. Require model output for both so
-- that a run that silently drops a group does not go unnoticed.
-- NOTE(review): assumes the model table carries the grouping column g.
SELECT assert(
    count(DISTINCT g) = 2,
    'Multinomial regression with grouping: expected model rows for groups A and B'
)
FROM mglm_test;
-- Quoting test: the dependent variable is passed as an SQL expression that
-- itself contains quoted string literals (doubled single quotes inside the
-- argument). It maps cat 1 to label a and every other value to label c.
DROP TABLE IF EXISTS test3_quote;
CREATE TABLE test3_quote (
    feat1 INTEGER,
    feat2 INTEGER,
    cat   VARCHAR
);
INSERT INTO test3_quote(feat1, feat2, cat) VALUES
    (1,35,'1'),
    (2,33,'0'),
    (3,39,'1'),
    (1,37,'1'),
    (2,31,'1'),
    (3,36,'0'),
    (2,36,'2'),
    (2,31,'2');

DROP TABLE IF EXISTS test3_quote_out, test3_quote_out_summary;
SELECT multinom(
    'test3_quote',
    'test3_quote_out',
    'case when cat = ''1'' then ''a'' else ''c'' end',
    'ARRAY[1,feat1, feat2]');