-- blob: fa6fc7dd2428aa5e917db13c5ccf0aa65182a496 [file] [log] [blame]
m4_include(`SQLCommon.m4')
/* -----------------------------------------------------------------------------
* Test Logistic Regression.
* -------------------------------------------------------------------------- */
/*
* The following example is taken from:
* http://luna.cas.usf.edu/~mbrannic/files/regression/Logistic.html
* Predicting heart attack
*/
-- Fixture: 23 patient rows for the heart-attack example cited above.
-- Rows 0, 21 and 22 each contain one NULL (second_attack, trait_anxiety and
-- treatment respectively); the assertions further down expect exactly these
-- three rows to be skipped and the other 20 to be processed.
--
-- The view is dropped before the table: it depends on the table, so on a
-- re-run DROP TABLE would fail while a view from a previous run still exists.
DROP VIEW IF EXISTS patients_view;
DROP TABLE IF EXISTS patients;
CREATE TABLE patients (
    id INTEGER NOT NULL,
    second_attack INTEGER,
    treatment INTEGER,
    trait_anxiety INTEGER
)m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)');
INSERT INTO patients(id, second_attack, treatment, trait_anxiety) VALUES
    ( 0, NULL, 1, 70),
    ( 1, 1, 1, 70),
    ( 2, 1, 1, 80),
    ( 3, 1, 1, 50),
    ( 4, 1, 0, 60),
    ( 5, 1, 0, 40),
    ( 6, 1, 0, 65),
    ( 7, 1, 0, 75),
    ( 8, 1, 0, 80),
    ( 9, 1, 0, 70),
    (10, 1, 0, 60),
    (11, 0, 1, 65),
    (12, 0, 1, 50),
    (13, 0, 1, 45),
    (14, 0, 1, 35),
    (15, 0, 1, 40),
    (16, 0, 1, 50),
    (17, 0, 0, 55),
    (18, 0, 0, 45),
    (19, 0, 0, 50),
    (20, 0, 0, 60),
    (21, 0, 0, NULL),
    (22, 0, NULL, 60);
-- Pass-through view over patients, used below to train from a view instead of
-- a base table. Columns are listed explicitly rather than with a star.
CREATE VIEW patients_view AS
SELECT
    id,
    second_attack,
    treatment,
    trait_anxiety
FROM patients;
-- Train on the patients base table with the IRLS optimizer. Output tables left
-- over from an earlier run are removed first. Argument roles follow the MADlib
-- logregr_train signature.
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'patients',                            -- source relation
    'temp_result',                         -- output table
    'second_attack',                       -- dependent variable
    'ARRAY[1, treatment, trait_anxiety]',  -- independent variables, 1 = intercept
    NULL,                                  -- no grouping columns
    20,                                    -- maximum number of iterations
    'irls'                                 -- optimizer
);
-- Reference coefficients come from the example cited above; the remaining
-- reference values were computed with the IRLS optimizer in MADlib.
-- The row counts reflect the 3 fixture rows that contain a NULL.
SELECT assert(
        relative_error(coef, ARRAY[-6.36, -1.02, 0.119]) < 1e-3
    AND relative_error(log_likelihood, -9.41) < 1e-3
    AND relative_error(std_err, ARRAY[3.21, 1.17, 0.0550]) < 0.002
    AND relative_error(z_stats, ARRAY[-1.98, -0.874, 2.17]) < 0.002
    AND relative_error(p_values, ARRAY[0.0477, 0.382, 0.0304]) < 1e-3
    AND relative_error(odds_ratios, ARRAY[0.00172, 0.359, 1.13]) < 0.004
    AND relative_error(condition_no, sqrt(106329)) < 1e-2
    AND num_rows_processed = 20
    AND num_missing_rows_skipped = 3,
    'Logistic regression with IRLS optimizer (patients test): Wrong results'
)
FROM temp_result;
-- Same IRLS test as for the patients table, but trained from patients_view to
-- check that a view is accepted as the source relation. Argument roles follow
-- the MADlib logregr_train signature.
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'patients_view',                       -- source relation (a view)
    'temp_result',                         -- output table
    'second_attack',                       -- dependent variable
    'ARRAY[1, treatment, trait_anxiety]',  -- independent variables, 1 = intercept
    NULL,                                  -- no grouping columns
    20,                                    -- maximum number of iterations
    'irls'                                 -- optimizer
);
-- Reference coefficients come from the example cited at the top of the file;
-- the remaining reference values were computed with the IRLS optimizer in
-- MADlib. The failure message names patients_view so that a failure here can be
-- told apart from a failure of the base-table test.
SELECT assert(
        relative_error(coef, ARRAY[-6.36, -1.02, 0.119]) < 1e-3
    AND relative_error(log_likelihood, -9.41) < 1e-3
    AND relative_error(std_err, ARRAY[3.21, 1.17, 0.0550]) < 0.002
    AND relative_error(z_stats, ARRAY[-1.98, -0.874, 2.17]) < 0.002
    AND relative_error(p_values, ARRAY[0.0477, 0.382, 0.0304]) < 1e-3
    AND relative_error(odds_ratios, ARRAY[0.00172, 0.359, 1.13]) < 0.004
    AND relative_error(condition_no, sqrt(106329)) < 1e-2
    AND num_rows_processed = 20
    AND num_missing_rows_skipped = 3,
    'Logistic regression with IRLS optimizer (patients_view test): Wrong results'
)
FROM temp_result;
-- Train on the patients table with the conjugate-gradient optimizer.
-- Argument roles follow the MADlib logregr_train signature.
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'patients',                            -- source relation
    'temp_result',                         -- output table
    'second_attack',                       -- dependent variable
    'ARRAY[1, treatment, trait_anxiety]',  -- independent variables, 1 = intercept
    NULL,                                  -- no grouping columns
    200,                                   -- maximum number of iterations
    'cg',                                  -- optimizer
    0                                      -- tolerance; NOTE(review): 0 presumably
                                           -- turns off the convergence check so all
                                           -- 200 iterations run - confirm in the docs
);
-- Tolerances are deliberately looser than in the IRLS tests above. The failure
-- message names the CG optimizer, in line with the IRLS messages.
SELECT assert(
        relative_error(coef, ARRAY[-6.36, -1.02, 0.119]) < 0.04
    AND relative_error(log_likelihood, -9.41) < 1e-3
    AND relative_error(std_err, ARRAY[3.21, 1.17, 0.0550]) < 0.02
    AND relative_error(z_stats, ARRAY[-1.98, -0.874, 2.17]) < 0.02
    AND relative_error(p_values, ARRAY[0.0477, 0.382, 0.0304]) < 0.03
    AND relative_error(odds_ratios, ARRAY[0.00172, 0.359, 1.13]) < 0.01
    AND relative_error(condition_no, sqrt(106329)) < 0.02
    AND num_rows_processed = 20
    AND num_missing_rows_skipped = 3,
    'Logistic regression with CG optimizer (patients test): Wrong results'
)
FROM temp_result;
-- IGD performs poorly on this instance, so we are not testing it
/*
* The following example is taken from:
* http://www.ats.ucla.edu/stat/stata/output/old/lognoframe.htm
* Crime data
*
* Description
* These data are crime-related and demographic statistics for 47 US states in
* 1960. The data were collected from the FBI's Uniform Crime Report and other
* government agencies to determine how the variable crime rate depends on the
* other variables measured in the study.
* Number of cases: 47
*
* CrimeRat: Crime rate: # of offenses reported to police per million population
* MaleTeen: The number of males of age 14-24 per 1000 population
* South : Indicator variable for Southern states (0 = No, 1 = Yes)
* Educ : Mean # of years of schooling for persons of age 25 or older
* Police60: 1960 per capita expenditure on police by state and local government
* Police59: 1959 per capita expenditure on police by state and local government
* Labor : Labor force participation rate per 1000 civilian urban males age 14-24
* Males : The number of males per 1000 females
* Pop : State population size in hundred thousands
* NonWhite: The number of non-whites per 1000 population
* Unemp1 : Unemployment rate of urban males per 1000 of age 14-24
* Unemp2 : Unemployment rate of urban males per 1000 of age 35-39
* Median : Median value of transferable goods and assets or family income in tens of $
* BelowMed: The number of families per 1000 earning below 1/2 the median income
*/
-- Fixture: the 47-row crime data set described in the comment above.
-- The id column is not listed in the INSERT, so it is filled by the SERIAL
-- default. Column names are unquoted and therefore case-insensitive.
DROP TABLE IF EXISTS crime;
CREATE TABLE crime (
    id       SERIAL NOT NULL,
    crimerat DOUBLE PRECISION,
    maleteen INTEGER,
    south    SMALLINT,
    educ     DOUBLE PRECISION,
    police60 INTEGER,
    police59 INTEGER,
    labor    INTEGER,
    males    INTEGER,
    pop      INTEGER,
    nonwhite INTEGER,
    unemp1   INTEGER,
    unemp2   INTEGER,
    median   INTEGER,
    belowmed INTEGER
)m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)');
INSERT INTO crime(
    crimerat, maleteen, south, educ, police60, police59, labor, males, pop,
    nonwhite, unemp1, unemp2, median, belowmed
) VALUES
    (79.1, 151, 1, 9.1, 58, 56, 510, 950, 33, 301, 108, 41, 394, 261),
    (163.5, 143, 0, 11.3, 103, 95, 583, 1012, 13, 102, 96, 36, 557, 194),
    (57.8, 142, 1, 8.9, 45, 44, 533, 969, 18, 219, 94, 33, 318, 250),
    (196.9, 136, 0, 12.1, 149, 141, 577, 994, 157, 80, 102, 39, 673, 167),
    (123.4, 141, 0, 12.1, 109, 101, 591, 985, 18, 30, 91, 20, 578, 174),
    (68.2, 121, 0, 11.0, 118, 115, 547, 964, 25, 44, 84, 29, 689, 126),
    (96.3, 127, 1, 11.1, 82, 79, 519, 982, 4, 139, 97, 38, 620, 168),
    (155.5, 131, 1, 10.9, 115, 109, 542, 969, 50, 179, 79, 35, 472, 206),
    (85.6, 157, 1, 9.0, 65, 62, 553, 955, 39, 286, 81, 28, 421, 239),
    (70.5, 140, 0, 11.8, 71, 68, 632, 1029, 7, 15, 100, 24, 526, 174),
    (167.4, 124, 0, 10.5, 121, 116, 580, 966, 101, 106, 77, 35, 657, 170),
    (84.9, 134, 0, 10.8, 75, 71, 595, 972, 47, 59, 83, 31, 580, 172),
    (51.1, 128, 0, 11.3, 67, 60, 624, 972, 28, 10, 77, 25, 507, 206),
    (66.4, 135, 0, 11.7, 62, 61, 595, 986, 22, 46, 77, 27, 529, 190),
    (79.8, 152, 1, 8.7, 57, 53, 530, 986, 30, 72, 92, 43, 405, 264),
    (94.6, 142, 1, 8.8, 81, 77, 497, 956, 33, 321, 116, 47, 427, 247),
    (53.9, 143, 0, 11.0, 66, 63, 537, 977, 10, 6, 114, 35, 487, 166),
    (92.9, 135, 1, 10.4, 123, 115, 537, 978, 31, 170, 89, 34, 631, 165),
    (75.0, 130, 0, 11.6, 128, 128, 536, 934, 51, 24, 78, 34, 627, 135),
    (122.5, 125, 0, 10.8, 113, 105, 567, 985, 78, 94, 130, 58, 626, 166),
    (74.2, 126, 0, 10.8, 74, 67, 602, 984, 34, 12, 102, 33, 557, 195),
    (43.9, 157, 1, 8.9, 47, 44, 512, 962, 22, 423, 97, 34, 288, 276),
    (121.6, 132, 0, 9.6, 87, 83, 564, 953, 43, 92, 83, 32, 513, 227),
    (96.8, 131, 0, 11.6, 78, 73, 574, 1038, 7, 36, 142, 42, 540, 176),
    (52.3, 130, 0, 11.6, 63, 57, 641, 984, 14, 26, 70, 21, 486, 196),
    (199.3, 131, 0, 12.1, 160, 143, 631, 1071, 3, 77, 102, 41, 674, 152),
    (34.2, 135, 0, 10.9, 69, 71, 540, 965, 6, 4, 80, 22, 564, 139),
    (121.6, 152, 0, 11.2, 82, 76, 571, 1018, 10, 79, 103, 28, 537, 215),
    (104.3, 119, 0, 10.7, 166, 157, 521, 938, 168, 89, 92, 36, 637, 154),
    (69.6, 166, 1, 8.9, 58, 54, 521, 973, 46, 254, 72, 26, 396, 237),
    (37.3, 140, 0, 9.3, 55, 54, 535, 1045, 6, 20, 135, 40, 453, 200),
    (75.4, 125, 0, 10.9, 90, 81, 586, 964, 97, 82, 105, 43, 617, 163),
    (107.2, 147, 1, 10.4, 63, 64, 560, 972, 23, 95, 76, 24, 462, 233),
    (92.3, 126, 0, 11.8, 97, 97, 542, 990, 18, 21, 102, 35, 589, 166),
    (65.3, 123, 0, 10.2, 97, 87, 526, 948, 113, 76, 124, 50, 572, 158),
    (127.2, 150, 0, 10.0, 109, 98, 531, 964, 9, 24, 87, 38, 559, 153),
    (83.1, 177, 1, 8.7, 58, 56, 638, 974, 24, 349, 76, 28, 382, 254),
    (56.6, 133, 0, 10.4, 51, 47, 599, 1024, 7, 40, 99, 27, 425, 225),
    (82.6, 149, 1, 8.8, 61, 54, 515, 953, 36, 165, 86, 35, 395, 251),
    (115.1, 145, 1, 10.4, 82, 74, 560, 981, 96, 126, 88, 31, 488, 228),
    (88.0, 148, 0, 12.2, 72, 66, 601, 998, 9, 19, 84, 20, 590, 144),
    (54.2, 141, 0, 10.9, 56, 54, 523, 968, 4, 2, 107, 37, 489, 170),
    (82.3, 162, 1, 9.9, 75, 70, 522, 996, 40, 208, 73, 27, 496, 224),
    (103.0, 136, 0, 12.1, 95, 96, 574, 1012, 29, 36, 111, 37, 622, 162),
    (45.5, 139, 1, 8.8, 46, 41, 480, 968, 19, 49, 135, 53, 457, 249),
    (50.8, 126, 0, 10.4, 106, 97, 599, 989, 40, 24, 78, 25, 593, 171),
    (84.9, 130, 0, 12.1, 90, 91, 623, 1049, 3, 22, 113, 40, 588, 160);
-- Train on the crime data with a boolean expression as the dependent variable.
-- No optimizer argument is passed, so the library default applies.
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'crime',                                     -- source relation
    'temp_result',                               -- output table
    'crimerat >= 110',                           -- dependent variable (expression)
    'array[1, maleteen, south, educ, police59]'  -- independent variables
);
-- Compare every fitted statistic with its reference value.
SELECT assert(
        relative_error(coef, ARRAY[-17.70177, .0833837, -1.117091, .0229224, .0581834]) < 1e-5
    AND relative_error(log_likelihood, -18.606959) < 1e-5
    AND relative_error(std_err, ARRAY[9.495993, .0440353, 1.359616, .5594047, .0210049]) < 1e-5
    AND relative_error(z_stats, ARRAY[-1.864, 1.894, -0.822, 0.041, 2.770]) < 1e-3
    AND relative_error(p_values, ARRAY[0.062, 0.058, 0.411, 0.967, 0.006]) < 1e-3
    AND relative_error(odds_ratios, ARRAY[exp(-17.70177), 1.086959, .3272305, 1.023187, 1.059909]) < 1e-5
    AND relative_error(condition_no, sqrt(14513863)) < 1e-2,
    'Logistic regression with IRLS optimizer (crime): Wrong results'
)
FROM temp_result;
-- Neither the conjugate-gradient nor incremental-gradient-descent optimizers
-- perform reasonably here, so we do not test them.
/*
* The following example is taken from:
* http://www.ats.ucla.edu/stat/r/dae/logit.htm
* Predicting graduate school admission
*/
-- Fixture: graduate-school admission data. The id column is not part of the
-- COPY column list, so it is filled by the SERIAL default. The COPY statement
-- must stay the last line of this step: its data rows follow immediately and
-- end at the backslash-period terminator.
DROP TABLE IF EXISTS grad_school_summary;
DROP TABLE IF EXISTS grad_school;
CREATE TABLE grad_school (
    id    SERIAL NOT NULL,
    admit INTEGER,
    gre   INTEGER,
    gpa   DOUBLE PRECISION,
    rank  INTEGER
)m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)');
COPY grad_school (admit, gre, gpa, rank) FROM STDIN;
0 380 3.61 3
1 660 3.67 3
1 800 4 1
1 640 3.19 4
0 520 2.93 4
1 760 3 2
1 560 2.98 1
0 400 3.08 2
1 540 3.39 3
0 700 3.92 2
0 800 4 4
0 440 3.22 1
1 760 4 1
0 700 3.08 2
1 700 4 1
0 480 3.44 3
0 780 3.87 4
0 360 2.56 3
0 800 3.75 2
1 540 3.81 1
0 500 3.17 3
1 660 3.63 2
0 600 2.82 4
0 680 3.19 4
1 760 3.35 2
1 800 3.66 1
1 620 3.61 1
1 520 3.74 4
1 780 3.22 2
0 520 3.29 1
0 540 3.78 4
0 760 3.35 3
0 600 3.4 3
1 800 4 3
0 360 3.14 1
0 400 3.05 2
0 580 3.25 1
0 520 2.9 3
1 500 3.13 2
1 520 2.68 3
0 560 2.42 2
1 580 3.32 2
1 600 3.15 2
0 500 3.31 3
0 700 2.94 2
1 460 3.45 3
1 580 3.46 2
0 500 2.97 4
0 440 2.48 4
0 400 3.35 3
0 640 3.86 3
0 440 3.13 4
0 740 3.37 4
1 680 3.27 2
0 660 3.34 3
1 740 4 3
0 560 3.19 3
0 380 2.94 3
0 400 3.65 2
0 600 2.82 4
1 620 3.18 2
0 560 3.32 4
0 640 3.67 3
1 680 3.85 3
0 580 4 3
0 600 3.59 2
0 740 3.62 4
0 620 3.3 1
0 580 3.69 1
0 800 3.73 1
0 640 4 3
0 300 2.92 4
0 480 3.39 4
0 580 4 2
0 720 3.45 4
0 720 4 3
0 560 3.36 3
1 800 4 3
0 540 3.12 1
1 620 4 1
0 700 2.9 4
0 620 3.07 2
0 500 2.71 2
0 380 2.91 4
1 500 3.6 3
0 520 2.98 2
0 600 3.32 2
0 600 3.48 2
0 700 3.28 1
1 660 4 2
0 700 3.83 2
1 720 3.64 1
0 800 3.9 2
0 580 2.93 2
1 660 3.44 2
0 660 3.33 2
0 640 3.52 4
0 480 3.57 2
0 700 2.88 2
0 400 3.31 3
0 340 3.15 3
0 580 3.57 3
0 380 3.33 4
0 540 3.94 3
1 660 3.95 2
1 740 2.97 2
1 700 3.56 1
0 480 3.13 2
0 400 2.93 3
0 480 3.45 2
0 680 3.08 4
0 420 3.41 4
0 360 3 3
0 600 3.22 1
0 720 3.84 3
0 620 3.99 3
1 440 3.45 2
0 700 3.72 2
1 800 3.7 1
0 340 2.92 3
1 520 3.74 2
1 480 2.67 2
0 520 2.85 3
0 500 2.98 3
0 720 3.88 3
0 540 3.38 4
1 600 3.54 1
0 740 3.74 4
0 540 3.19 2
0 460 3.15 4
1 620 3.17 2
0 640 2.79 2
0 580 3.4 2
0 500 3.08 3
0 560 2.95 2
0 500 3.57 3
0 560 3.33 4
0 700 4 3
0 620 3.4 2
1 600 3.58 1
0 640 3.93 2
1 700 3.52 4
0 620 3.94 4
0 580 3.4 3
0 580 3.4 4
0 380 3.43 3
0 480 3.4 2
0 560 2.71 3
1 480 2.91 1
0 740 3.31 1
1 800 3.74 1
0 400 3.38 2
1 640 3.94 2
0 580 3.46 3
0 620 3.69 3
1 580 2.86 4
0 560 2.52 2
1 480 3.58 1
0 660 3.49 2
0 700 3.82 3
0 600 3.13 2
0 640 3.5 2
1 700 3.56 2
0 520 2.73 2
0 580 3.3 2
0 700 4 1
0 440 3.24 4
0 720 3.77 3
0 500 4 3
0 600 3.62 3
0 400 3.51 3
0 540 2.81 3
0 680 3.48 3
1 800 3.43 2
0 500 3.53 4
1 620 3.37 2
0 520 2.62 2
1 620 3.23 3
0 620 3.33 3
0 300 3.01 3
0 620 3.78 3
0 500 3.88 4
0 700 4 2
1 540 3.84 2
0 500 2.79 4
0 800 3.6 2
0 560 3.61 3
0 580 2.88 2
0 560 3.07 2
0 500 3.35 2
1 640 2.94 2
0 800 3.54 3
0 640 3.76 3
0 380 3.59 4
1 600 3.47 2
0 560 3.59 2
0 660 3.07 3
1 400 3.23 4
0 600 3.63 3
0 580 3.77 4
0 800 3.31 3
1 580 3.2 2
1 700 4 1
0 420 3.92 4
1 600 3.89 1
1 780 3.8 3
0 740 3.54 1
1 640 3.63 1
0 540 3.16 3
0 580 3.5 2
0 740 3.34 4
0 580 3.02 2
0 460 2.87 2
0 640 3.38 3
1 600 3.56 2
1 660 2.91 3
0 340 2.9 1
1 460 3.64 1
0 460 2.98 1
1 560 3.59 2
0 540 3.28 3
0 680 3.99 3
1 480 3.02 1
0 800 3.47 3
0 800 2.9 2
1 720 3.5 3
0 620 3.58 2
0 540 3.02 4
0 480 3.43 2
1 720 3.42 2
0 580 3.29 4
0 600 3.28 3
0 380 3.38 2
0 420 2.67 3
1 800 3.53 1
0 620 3.05 2
1 660 3.49 2
0 480 4 2
0 500 2.86 4
0 700 3.45 3
0 440 2.76 2
1 520 3.81 1
1 680 2.96 3
0 620 3.22 2
0 540 3.04 1
0 800 3.91 3
0 680 3.34 2
0 440 3.17 2
0 680 3.64 3
0 640 3.73 3
0 660 3.31 4
0 620 3.21 4
1 520 4 2
1 540 3.55 4
1 740 3.52 4
0 640 3.35 3
1 520 3.3 2
1 620 3.95 3
0 520 3.51 2
0 640 3.81 2
0 680 3.11 2
0 440 3.15 2
1 520 3.19 3
1 620 3.95 3
1 520 3.9 3
0 380 3.34 3
0 560 3.24 4
1 600 3.64 3
1 680 3.46 2
0 500 2.81 3
1 640 3.95 2
0 540 3.33 3
1 680 3.67 2
0 660 3.32 1
0 520 3.12 2
1 600 2.98 2
0 460 3.77 3
1 580 3.58 1
1 680 3 4
1 660 3.14 2
0 660 3.94 2
0 360 3.27 3
0 660 3.45 4
0 520 3.1 4
1 440 3.39 2
0 600 3.31 4
1 800 3.22 1
1 660 3.7 4
0 800 3.15 4
0 420 2.26 4
1 620 3.45 2
0 800 2.78 2
0 680 3.7 2
0 800 3.97 1
0 480 2.55 1
0 520 3.25 3
0 560 3.16 1
0 460 3.07 2
0 540 3.5 2
0 720 3.4 3
0 640 3.3 2
1 660 3.6 3
1 400 3.15 2
1 680 3.98 2
0 220 2.83 3
0 580 3.46 4
1 540 3.17 1
0 580 3.51 2
0 540 3.13 2
0 440 2.98 3
0 560 4 3
0 660 3.67 2
0 660 3.77 3
1 520 3.65 4
0 540 3.46 4
1 300 2.84 2
1 340 3 2
1 780 3.63 4
1 480 3.71 4
0 540 3.28 1
0 460 3.14 3
0 460 3.58 2
0 500 3.01 4
0 420 2.69 2
0 520 2.7 3
0 680 3.9 1
0 680 3.31 2
1 560 3.48 2
0 580 3.34 2
0 500 2.93 4
0 740 4 3
0 660 3.59 3
0 420 2.96 1
0 560 3.43 3
1 460 3.64 3
1 620 3.71 1
0 520 3.15 3
0 620 3.09 4
0 540 3.2 1
1 660 3.47 3
0 500 3.23 4
1 560 2.65 3
0 500 3.95 4
0 580 3.06 2
0 520 3.35 3
0 500 3.03 3
0 600 3.35 2
0 580 3.8 2
0 400 3.36 2
0 620 2.85 2
1 780 4 2
0 620 3.43 3
1 580 3.12 3
0 700 3.52 2
1 540 3.78 2
1 760 2.81 1
0 700 3.27 2
0 720 3.31 1
1 560 3.69 3
0 720 3.94 3
1 520 4 1
1 540 3.49 1
0 680 3.14 2
0 460 3.44 2
1 560 3.36 1
0 480 2.78 3
0 460 2.93 3
0 620 3.63 3
0 580 4 1
0 800 3.89 2
1 540 3.77 2
1 680 3.76 3
1 680 2.42 1
1 620 3.37 1
0 560 3.78 2
0 560 3.49 4
0 620 3.63 2
1 800 4 2
0 640 3.12 3
0 540 2.7 2
0 700 3.65 2
1 540 3.49 2
0 540 3.51 2
0 660 4 1
1 480 2.62 2
0 420 3.02 1
1 740 3.86 2
0 580 3.36 2
0 640 3.17 2
0 640 3.51 2
1 800 3.05 2
1 660 3.88 2
1 600 3.38 3
1 620 3.75 2
1 460 3.99 3
0 620 4 2
0 560 3.04 3
0 460 2.63 2
0 700 3.65 2
0 600 3.89 3
\.
-- Train on grad_school with rank encoded as three indicator variables
-- (rank = 2, 3, 4; rank = 1 is the baseline). No optimizer argument is passed,
-- so the library default applies.
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'grad_school',  -- source relation
    'temp_result',  -- output table
    'admit',        -- dependent variable
    'ARRAY[1, gre, gpa, (rank = 2)::INT::FLOAT8, (rank = 3)::INT::FLOAT8, (rank = 4)::INT::FLOAT8]'
);
-- Compare every fitted statistic with its reference value.
SELECT assert(
        relative_error(coef, ARRAY[-3.989979, 0.002264, 0.804038, -0.675443, -1.340204, -1.551464]) < 1e-5
    AND relative_error(log_likelihood, -229.2587) < 1e-5
    AND relative_error(std_err, ARRAY[1.139951, 0.001094, 0.331819, 0.316490, 0.345306, 0.417832]) < 1e-5
    AND relative_error(z_stats, ARRAY[-3.500, 2.070, 2.423, -2.134, -3.881, -3.713]) < 1e-3
    AND relative_error(p_values, ARRAY[0.000465, 0.038465, 0.015388, 0.032829, 0.000104, 0.000205]) < 1e-3
    AND relative_error(odds_ratios, ARRAY[0.0185001, 1.0022670, 2.2345448, 0.5089310, 0.2617923, 0.2119375]) < 1e-5
    AND relative_error(condition_no, 6379.028) < 1e-2,
    'Logistic regression with IRLS optimizer (grad_school): Wrong results'
)
FROM temp_result;
-- Smoke test of the two prediction functions: each is evaluated for every
-- pairing of a temp_result row with a grad_school row, using the same feature
-- vector layout as in training. The output is not asserted; this only checks
-- that the calls run without error.
SELECT logregr_predict(
    model.coef,
    ARRAY[1, g.gre, g.gpa, (g.rank = 2)::INT::FLOAT8, (g.rank = 3)::INT::FLOAT8, (g.rank = 4)::INT::FLOAT8]
)
FROM temp_result AS model
CROSS JOIN grad_school AS g;
SELECT logregr_predict_prob(
    model.coef,
    ARRAY[1, g.gre, g.gpa, (g.rank = 2)::INT::FLOAT8, (g.rank = 3)::INT::FLOAT8, (g.rank = 4)::INT::FLOAT8]
)
FROM temp_result AS model
CROSS JOIN grad_school AS g;
-- Train on grad_school with the conjugate-gradient optimizer. No assertion
-- follows this call; the CG assertions near the end of the file are commented
-- out, so this step only checks that training runs. Argument roles follow the
-- MADlib logregr_train signature.
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'grad_school',  -- source relation
    'temp_result',  -- output table
    'admit',        -- dependent variable
    'ARRAY[1, gre, gpa, (rank = 2)::INT::FLOAT8, (rank = 3)::INT::FLOAT8, (rank = 4)::INT::FLOAT8]',
    NULL,           -- no grouping columns
    20,             -- maximum number of iterations
    'cg',           -- optimizer
    0               -- tolerance
);
-- NULL-handling test: add two rows with missing values (one without admit and
-- gre, one without gre and with rank = 5), then train one model per rank group.
-- The result is not asserted, so this only checks that grouped training
-- completes when NULLs are present.
INSERT INTO grad_school (admit, gre, gpa, rank) VALUES
    (NULL, NULL, 3, 4),
    (1, NULL, 3, 5);
DROP TABLE IF EXISTS temp_result;
DROP TABLE IF EXISTS temp_result_summary;
SELECT logregr_train(
    'grad_school',         -- source relation
    'temp_result',         -- output table
    'admit',               -- dependent variable
    'ARRAY[1, gre, gpa]',  -- independent variables, 1 = intercept
    'rank'                 -- grouping column
);
-- Even with far more generous tolerances, the conjugate-gradient optimizer
-- still fails intermittently on grad_school. The assertions for the CG run are
-- therefore commented out, acknowledging that CG may produce unconverged results.
-- SELECT
-- assert(relative_error(coef, ARRAY[-3.989979, 0.002264, 0.804038, -0.675443, -1.340204, -1.551464]) < 0.1, 'Logistic regression with CG optimizer (grad_school): Wrong coef'),
-- assert(relative_error(log_likelihood, -229.2587) < 0.1, 'Logistic regression with CG optimizer (grad_school): Wrong log_likelihood'),
-- assert(relative_error(std_err, ARRAY[1.139951, 0.001094, 0.331819, 0.316490, 0.345306, 0.417832]) < 0.1, 'Logistic regression with CG optimizer (grad_school): Wrong std_err'),
-- assert(relative_error(z_stats, ARRAY[-3.500, 2.070, 2.423, -2.134, -3.881, -3.713]) < 0.1, 'Logistic regression with CG optimizer (grad_school): Wrong z-stats'),
-- assert(relative_error(p_values, ARRAY[0.000465, 0.038465, 0.015388, 0.032829, 0.000104, 0.000205]) < 1, 'Logistic regression with CG optimizer (grad_school): Wrong p-values'),
-- assert(relative_error(odds_ratios, ARRAY[0.0185001, 1.0022670, 2.2345448, 0.5089310, 0.2617923, 0.2119375]) < 0.1, 'Logistic regression with CG optimizer (grad_school): Wrong odds_ratios')
-- FROM temp_result;
-- IGD essentially does not work for this case, so we are not testing it