| /* ----------------------------------------------------------------------------- |
| * Test Multinomial Logistic Regression. |
| * -------------------------------------------------------------------------- */ |
| |
| /* |
| * The values given by the multinomial logistic regression were cross checked |
| * with the Matlab command mnrfit, which is documented at |
| * http://www.mathworks.com/help/toolbox/stats/mnrfit.html |
| * |
| * One important detail: because of a difference in sign convention, the |
| * coefficients returned by mnrfit are the negatives of our coefficients. Our |
| * convention is chosen to match that of the binary logistic regression |
| * implementation in MADlib. |
| * |
| * For completeness, the matlab code needed to check the answers to the 'test3' example |
| * is included below. The code assumes that the data is contained in a csv file |
| * and that the columns haven't changed order. The coefficients will be in the |
| * 'B' variable. |
| * |
| * BEGIN CODE |
| * |
| data = csvread(csvFilename); |
| N = size(data, 1); % Number of records |
| J = size(data, 2)-1; % Number of covariates |
| |
| % Integer encoded categories {0,1...K-1} |
| int_y = 1+data(:,end); % Categories |
| x = data(:,1:end-1); % Independent variables |
| |
| % Pivot around the last category (mnrfit uses it as the reference category) |
| [B,dev,stats] = mnrfit(x,int_y) |
| * |
| * END CODE |
| */ |
| |
| -- Fixture table for the multinomial regression tests. Recreated from scratch |
| -- on every run. |
| DROP TABLE IF EXISTS multinom_test; |
| |
| CREATE TABLE multinom_test ( |
|     feat1  INTEGER,   -- covariate |
|     feat2  INTEGER,   -- covariate |
|     cat    INTEGER,   -- response passed to multinom as the dependent variable |
|     g      CHAR(1)    -- single-character label, used by the grouped fit |
| ); |
| |
| -- Training data: 200 observations of (feat1, feat2, cat, g). |
| -- The first 152 rows carry g = 'A' and the remaining 48 rows carry g = 'B'. |
| -- Values are reference data cross-checked against Matlab (see file header); |
| -- do not reorder or edit them without regenerating the expected results. |
| INSERT INTO multinom_test(feat1, feat2, cat, g) VALUES |
| (1,35,1,'A'), |
| (2,33,0,'A'), |
| (3,39,1,'A'), |
| (1,37,1,'A'), |
| (2,31,1,'A'), |
| (3,36,0,'A'), |
| (2,36,1,'A'), |
| (2,31,1,'A'), |
| (2,41,1,'A'), |
| (2,37,1,'A'), |
| (1,44,1,'A'), |
| (3,33,2,'A'), |
| (1,31,1,'A'), |
| (2,44,1,'A'), |
| (1,35,1,'A'), |
| (1,44,0,'A'), |
| (1,46,0,'A'), |
| (2,46,1,'A'), |
| (2,46,2,'A'), |
| (3,49,1,'A'), |
| (2,39,0,'A'), |
| (2,44,1,'A'), |
| (1,47,1,'A'), |
| (1,44,1,'A'), |
| (1,37,2,'A'), |
| (3,38,2,'A'), |
| (1,49,0,'A'), |
| (2,44,0,'A'), |
| (1,41,2,'A'), |
| (1,50,2,'A'), |
| (2,44,0,'A'), |
| (1,39,1,'A'), |
| (1,40,2,'A'), |
| (1,46,2,'A'), |
| (2,41,1,'A'), |
| (2,39,1,'A'), |
| (2,33,1,'A'), |
| (3,59,2,'A'), |
| (1,41,0,'A'), |
| (2,47,2,'A'), |
| (2,31,0,'A'), |
| (3,42,2,'A'), |
| (1,55,2,'A'), |
| (3,40,1,'A'), |
| (1,44,2,'A'), |
| (1,54,1,'A'), |
| (2,46,1,'A'), |
| (1,54,0,'A'), |
| (2,42,1,'A'), |
| (2,49,2,'A'), |
| (2,41,2,'A'), |
| (2,41,1,'A'), |
| (1,44,0,'A'), |
| (1,57,2,'A'), |
| (2,52,2,'A'), |
| (1,49,0,'A'), |
| (3,41,2,'A'), |
| (3,57,0,'A'), |
| (1,62,1,'A'), |
| (3,33,0,'A'), |
| (2,54,1,'A'), |
| (2,40,2,'A'), |
| (3,52,2,'A'), |
| (2,57,1,'A'), |
| (2,49,1,'A'), |
| (2,46,1,'A'), |
| (1,57,0,'A'), |
| (2,49,2,'A'), |
| (2,52,2,'A'), |
| (2,53,0,'A'), |
| (3,54,2,'A'), |
| (2,57,2,'A'), |
| (3,41,2,'A'), |
| (1,52,0,'A'), |
| (2,57,1,'A'), |
| (1,54,0,'A'), |
| (2,52,1,'A'), |
| (2,52,0,'A'), |
| (2,44,0,'A'), |
| (2,46,2,'A'), |
| (1,49,1,'A'), |
| (2,54,2,'A'), |
| (3,52,2,'A'), |
| (1,44,0,'A'), |
| (3,49,1,'A'), |
| (1,46,2,'A'), |
| (2,54,0,'A'), |
| (2,39,0,'A'), |
| (2,59,0,'A'), |
| (2,45,1,'A'), |
| (3,52,1,'A'), |
| (3,54,0,'A'), |
| (3,44,1,'A'), |
| (2,50,2,'A'), |
| (2,62,1,'A'), |
| (2,59,0,'A'), |
| (2,52,2,'A'), |
| (2,52,1,'A'), |
| (2,46,1,'A'), |
| (2,41,0,'A'), |
| (2,52,2,'A'), |
| (2,52,1,'A'), |
| (2,55,1,'A'), |
| (2,41,1,'A'), |
| (2,49,0,'A'), |
| (1,59,2,'A'), |
| (1,54,0,'A'), |
| (2,54,0,'A'), |
| (2,59,2,'A'), |
| (2,55,2,'A'), |
| (1,62,2,'A'), |
| (2,54,2,'A'), |
| (2,54,2,'A'), |
| (2,54,2,'A'), |
| (2,59,2,'A'), |
| (2,57,1,'A'), |
| (3,61,2,'A'), |
| (3,52,2,'A'), |
| (2,59,2,'A'), |
| (2,62,2,'A'), |
| (1,60,1,'A'), |
| (2,59,2,'A'), |
| (2,65,2,'A'), |
| (3,61,2,'A'), |
| (2,59,2,'A'), |
| (3,59,2,'A'), |
| (2,59,2,'A'), |
| (2,59,2,'A'), |
| (2,65,2,'A'), |
| (3,57,2,'A'), |
| (2,59,2,'A'), |
| (3,49,2,'A'), |
| (1,49,0,'A'), |
| (3,59,2,'A'), |
| (2,62,2,'A'), |
| (3,59,0,'A'), |
| (2,54,2,'A'), |
| (3,63,2,'A'), |
| (1,43,2,'A'), |
| (3,54,2,'A'), |
| (3,52,2,'A'), |
| (1,57,2,'A'), |
| (2,57,0,'A'), |
| (2,57,0,'A'), |
| (2,61,2,'A'), |
| (2,62,0,'A'), |
| (2,62,0,'A'), |
| (1,65,0,'A'), |
| (2,57,2,'A'), |
| (3,59,2,'A'), |
| (2,59,2,'A'), |
| (3,62,2,'A'), |
| (2,65,2,'B'), |
| (2,62,1,'B'), |
| (1,62,0,'B'), |
| (2,62,2,'B'), |
| (3,54,2,'B'), |
| (3,62,2,'B'), |
| (1,65,2,'B'), |
| (3,62,2,'B'), |
| (3,67,0,'B'), |
| (3,65,0,'B'), |
| (1,60,2,'B'), |
| (3,59,2,'B'), |
| (2,59,2,'B'), |
| (2,59,1,'B'), |
| (3,65,0,'B'), |
| (3,62,2,'B'), |
| (3,65,2,'B'), |
| (3,59,0,'B'), |
| (1,59,0,'B'), |
| (3,61,2,'B'), |
| (1,65,2,'B'), |
| (3,67,1,'B'), |
| (3,65,2,'B'), |
| (1,65,2,'B'), |
| (2,67,2,'B'), |
| (1,65,2,'B'), |
| (1,62,2,'B'), |
| (3,52,2,'B'), |
| (3,63,2,'B'), |
| (2,59,2,'B'), |
| (3,65,2,'B'), |
| (2,59,0,'B'), |
| (3,67,2,'B'), |
| (3,67,2,'B'), |
| (3,60,2,'B'), |
| (3,67,2,'B'), |
| (3,62,2,'B'), |
| (2,54,2,'B'), |
| (3,65,2,'B'), |
| (3,62,2,'B'), |
| (2,59,2,'B'), |
| (3,60,2,'B'), |
| (3,63,2,'B'), |
| (3,65,2,'B'), |
| (2,63,1,'B'), |
| (2,67,2,'B'), |
| (2,65,2,'B'), |
| (2,62,2,'B'); |
| |
| -- Fit the ungrouped model on the full data set. |
| -- Output tables left over from an earlier run are dropped first so the script |
| -- can be re-run in the same schema; the later fits in this file already do so. |
| DROP TABLE IF EXISTS mglm_test, mglm_test_summary; |
| |
| -- Argument roles below follow the MADlib multinom documentation |
| -- (NOTE(review): confirm against the installed MADlib version). |
| SELECT multinom( |
|     'multinom_test',            -- source table |
|     'mglm_test',                -- output table |
|     'cat',                      -- dependent variable |
|     'ARRAY[feat1, feat2, 1]',   -- independent variables; constant term last |
|     '2',                        -- reference category |
|     'logit',                    -- link function |
|     NULL,                       -- no grouping |
|     NULL,                       -- default optimizer parameters |
|     TRUE);                      -- verbose |
| |
| -- Show the fitted model and its summary in expanded display for the test log. |
| \x on |
| SELECT * FROM mglm_test; |
| SELECT * FROM mglm_test_summary; |
| \x off |
| |
| -- Guard: the value check below is evaluated once per matching row, so it would |
| -- pass vacuously if the fit produced no row for category 1. |
| SELECT assert( |
|     count(*) > 0, |
|     'Multinomial regression with IRLS optimizer (test) on category 1: No result row' |
| ) |
| FROM mglm_test |
| WHERE category = '1'; |
| |
| -- Compare the fitted values for category 1 with the reference values that were |
| -- cross-checked against Matlab's mnrfit (see file header for the sign |
| -- convention). p_values are compared with a looser tolerance (1e-2) than the |
| -- other quantities (1e-4). |
| SELECT assert( |
|     coef IS NOT NULL AND |
|     relative_error(coef, ARRAY[-.45117861,-.11221493,5.9900061]) < 1e-4 AND |
|     relative_error(log_likelihood, -182.22) < 1e-4 AND |
|     relative_error(std_err, ARRAY[.27290708,.02169723,1.2093327]) < 1e-4 AND |
|     relative_error(z_stats, ARRAY[-1.6532316,-5.1718545,4.9531499]) < 1e-4 AND |
|     relative_error(p_values, ARRAY[.09828373,2.318e-07,7.302e-07]) < 1e-2, |
|     'Multinomial regression with IRLS optimizer (test) on category 1: Wrong results' |
| ) |
| FROM mglm_test |
| WHERE category = '1'; |
| |
| -- Prediction tests. |
| -- A SERIAL column is added so that each input row has an id to pass to |
| -- multinom_predict. |
| ALTER TABLE multinom_test |
|     ADD COLUMN i SERIAL; |
| |
| -- Predict with type 'response', passing the new id column 'i'. |
| DROP TABLE IF EXISTS mglm_predict; |
| SELECT multinom_predict( |
|     'mglm_test', 'multinom_test', 'mglm_predict', |
|     'response', TRUE, 'i' |
| ); |
| \x off |
| SELECT * FROM mglm_predict; |
| |
| -- Predict with type 'response' again, this time passing an array expression as |
| -- the last argument where the previous call passed the id column 'i'. |
| -- NOTE(review): confirm that an expression in this position is intentional. |
| DROP TABLE IF EXISTS mglm_predict; |
| SELECT multinom_predict( |
|     'mglm_test', 'multinom_test', 'mglm_predict', |
|     'response', TRUE, 'ARRAY[feat1, feat2, 1]' |
| ); |
| \x off |
| SELECT * FROM mglm_predict; |
| |
| -- Predict with type 'prob', with the same array expression as the last |
| -- argument as in the preceding 'response' call. |
| -- NOTE(review): confirm that an expression in this position is intentional. |
| DROP TABLE IF EXISTS mglm_predict; |
| SELECT multinom_predict( |
|     'mglm_test', 'multinom_test', 'mglm_predict', |
|     'prob', TRUE, 'ARRAY[feat1, feat2, 1]' |
| ); |
| \x off |
| SELECT * FROM mglm_predict; |
| |
| -- Grouped fit: discard the earlier model tables, then refit with 'g' passed in |
| -- the grouping position. Compared with the first fit, the constant term comes |
| -- first in the covariate array and the fifth argument is '0' instead of '2'. |
| DROP TABLE IF EXISTS mglm_test, mglm_test_summary; |
| SELECT multinom( |
|     'multinom_test', 'mglm_test', 'cat', |
|     'ARRAY[1, feat1, feat2]', |
|     '0', 'logit', |
|     'g', NULL, TRUE |
| ); |
| |
| -- Quoting test fixture: a small table whose response column is text, so the |
| -- fit below can use an expression containing quoted string literals. |
| DROP TABLE IF EXISTS test3_quote; |
| |
| CREATE TABLE test3_quote ( |
|     feat1  INTEGER, |
|     feat2  INTEGER, |
|     cat    VARCHAR |
| ); |
| |
| INSERT INTO test3_quote (feat1, feat2, cat) VALUES |
|     (1, 35, '1'), |
|     (2, 33, '0'), |
|     (3, 39, '1'), |
|     (1, 37, '1'), |
|     (2, 31, '1'), |
|     (3, 36, '0'), |
|     (2, 36, '2'), |
|     (2, 31, '2'); |
| |
| -- Fit using a CASE expression with embedded (doubled) single quotes as the |
| -- dependent variable; only the first four arguments are supplied. |
| DROP TABLE IF EXISTS test3_quote_out, test3_quote_out_summary; |
| SELECT multinom( |
|     'test3_quote', |
|     'test3_quote_out', |
|     'case when cat = ''1'' then ''a'' else ''c'' end', |
|     'ARRAY[1,feat1, feat2]' |
| ); |