| /* ----------------------------------------------------------------------------- |
| * Test Linear Support Vector Machine |
| * -------------------------------------------------------------------------- */ |
| |
| CREATE OR REPLACE FUNCTION __svm_target_cl_func(ind float8[]) |
| RETURNS float8 AS $$ |
| BEGIN |
| IF (ind[1] > 0 AND ind[2] < 0) THEN RETURN 1; END IF; |
| RETURN -1; |
| END |
| $$ LANGUAGE plpgsql; |
| |
| CREATE OR REPLACE FUNCTION __svr_target_cl_func(ind float8[]) |
| RETURNS float8 AS $$ |
| BEGIN |
| RETURN 1*ind[1] + 2*ind[2]; |
| END |
| $$ LANGUAGE plpgsql; |
| |
| CREATE OR REPLACE FUNCTION __svm_random_ind(d INT) |
| RETURNS float8[] AS $$ |
| DECLARE |
| ret float8[]; |
| BEGIN |
| FOR i IN 1..(d-1) LOOP |
| ret[i] = RANDOM() * 40 - 20; |
| END LOOP; |
| IF (RANDOM() > 0.5) THEN |
| ret[d] = 10; |
| ELSE |
| ret[d] = -10; |
| END IF; |
| RETURN ret; |
| END |
| $$ LANGUAGE plpgsql; |
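| |
| -- (optional) spot-check the helpers above: one random 4-dimensional point |
| -- and the class label the target function assigns to it |
| SELECT __svm_random_ind(4) AS sample_ind; |
| SELECT __svm_target_cl_func(__svm_random_ind(4)) AS sample_label; |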
| |
| CREATE OR REPLACE FUNCTION svm_generate_cls_data( |
| output_table text, num int, dim int) |
| RETURNS VOID AS $$ |
| DECLARE |
| temp_table text; |
| BEGIN |
| temp_table := 'madlib_temp_' || output_table; |
| EXECUTE ' |
| CREATE TABLE ' || temp_table || ' AS |
| SELECT |
| subq.val AS id, |
| __svm_random_ind(' || dim || ') AS ind |
| FROM |
| (SELECT generate_series(1, ' || num || ') AS val) subq'; |
| EXECUTE ' |
| CREATE TABLE ' || output_table || ' AS |
| SELECT id, ind, __svm_target_cl_func(ind) AS label |
| FROM ' || temp_table; |
| END |
| $$ LANGUAGE plpgsql; |
| |
| CREATE OR REPLACE FUNCTION svr_generate_cls_data( |
| output_table text, num int, dim int) |
| RETURNS VOID AS $$ |
| DECLARE |
| temp_table text; |
| BEGIN |
| temp_table := 'madlib_temp_' || output_table; |
| EXECUTE ' |
| CREATE TABLE ' || temp_table || ' AS |
| SELECT |
| subq.val AS id, |
| __svm_random_ind(' || dim || ') AS ind |
| FROM |
| (SELECT generate_series(1, ' || num || ') AS val) subq'; |
| EXECUTE ' |
| CREATE TABLE ' || output_table || ' AS |
| SELECT id, ind, __svr_target_cl_func(ind) AS label |
| FROM ' || temp_table; |
| END |
| $$ LANGUAGE plpgsql; |
| |
| SELECT svm_generate_cls_data('svm_train_data', 1000, 4); |
| SELECT svm_generate_cls_data('svm_test_data', 1000, 4); |
| SELECT svr_generate_cls_data('svr_train_data', 1000, 4); |
| SELECT svr_generate_cls_data('svr_test_data', 1000, 4); |
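| |
| -- (optional) quick look at the generated data: row count and class balance |
| -- of the training set, using only the tables created just above |
| SELECT count(*) AS n_rows FROM svm_train_data; |
| SELECT label, count(*) AS n FROM svm_train_data GROUP BY label ORDER BY label; |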
| |
| -- check the default values |
| SELECT svm_regression( |
| 'svr_train_data', |
| 'svr_model', |
| 'label', |
| 'ind'); |
| \x on |
| SELECT * FROM svr_model; |
| SELECT * FROM svr_model_summary; |
| \x off |
| SELECT |
| assert( |
| norm1(coef) < 4, |
| 'optimal coef should be close to [1, 2, 0, 0]!') |
| FROM svr_model; |
| |
| -- check the use of the l2 norm with explicitly specified parameters |
| SELECT svm_regression( |
| 'svr_train_data', |
| 'svr_model2', |
| 'label', |
| 'ind', |
| NULL, |
| NULL, |
| NULL, |
| 'init_stepsize=0.01, max_iter=50, lambda=2, norm=l2, epsilon=0.01', |
| false); |
| SELECT svm_predict('svr_model2', 'svr_train_data', 'id', 'svr_test_result'); |
| \x on |
| SELECT * FROM svr_model2; |
| \x off |
| SELECT |
| assert( |
| avg(subq.err) < 0.1, |
| 'prediction error is too large!') |
| FROM |
| ( |
| SELECT |
| train.id, |
| abs(train.label - test.prediction) AS err |
| FROM svr_train_data AS train, svr_test_result AS test |
| WHERE train.id = test.id |
| ) AS subq; |
| |
| -- Example usage for LINEAR classification |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'lclss', |
| 'label', |
| 'ind', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| NULL, -- grouping_col |
| 'max_iter=10, tolerance=0' -- optim_params |
| ); |
| SELECT * FROM lclss; |
| SELECT * FROM lclss_summary; |
| |
| SELECT svm_predict('lclss', 'svm_test_data', 'id', 'svm_test_predict'); |
| |
| -- checking correctness with pre-conditioning |
| CREATE TABLE svm_normalized AS |
| SELECT |
| id, |
| array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind, |
| label |
| FROM svm_train_data, |
| ( |
| SELECT ARRAY[avg(ind[1]),avg(ind[2]), |
| avg(ind[3]),avg(ind[4])] AS ind_avg |
| FROM svm_train_data |
| ) AS svm_ind_avg, |
| ( |
| SELECT ARRAY[stddev(ind[1]),stddev(ind[2]), |
| stddev(ind[3]),stddev(ind[4])] AS ind_stddev |
| FROM svm_train_data |
| ) AS svm_ind_stddev |
| ORDER BY random(); |
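| |
| -- (optional) sanity check on the pre-conditioning: each original feature |
| -- should now have mean ~0 and stddev ~1; the appended 5th element is the |
| -- constant intercept term |
| SELECT avg(ind[1]) AS mean_1, stddev(ind[1]) AS std_1, |
| avg(ind[4]) AS mean_4, stddev(ind[4]) AS std_4 |
| FROM svm_normalized; |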
| |
| CREATE TABLE svm_test_normalized AS |
| SELECT |
| id, |
| array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind, |
| label |
| FROM svm_test_data, |
| ( |
| SELECT ARRAY[avg(ind[1]),avg(ind[2]), |
| avg(ind[3]),avg(ind[4])] AS ind_avg |
| FROM svm_test_data |
| ) AS svm_test_ind_avg, |
| ( |
| SELECT ARRAY[stddev(ind[1]),stddev(ind[2]), |
| stddev(ind[3]),stddev(ind[4])] AS ind_stddev |
| FROM svm_test_data |
| ) AS svm_test_ind_stddev; |
| |
| ---------------------------------------------------------------- |
| -- serial |
| -- learning |
| SELECT svm_classification( |
| 'svm_normalized', |
| 'svm_model', |
| 'label', |
| 'ind', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| NULL, -- grouping_col |
| 'init_stepsize=0.03, decay_factor=1, max_iter=5, tolerance=0, lambda=0', |
| false -- verbose |
| ); |
| \x on |
| SELECT * FROM svm_model; |
| SELECT * FROM svm_model_summary; |
| \x off |
| |
| -- l2 |
| SELECT svm_classification( |
| 'svm_normalized', |
| 'svm_model_small_norm2', |
| 'label', |
| 'ind', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| NULL, -- grouping_col |
| 'init_stepsize=0.03, decay_factor=1, max_iter=5, tolerance=0, lambda=1' |
| ); |
| \x on |
| SELECT * FROM svm_model_small_norm2; |
| \x off |
| |
| SELECT |
| assert( |
| norm2(l2.coef) < norm2(noreg.coef) OR |
| ( |
| (norm2(l2.coef)-norm2(noreg.coef))/norm2(noreg.coef) < 0.1 AND |
| l2.loss < noreg.loss |
| ), |
| 'l2 regularization should produce coef with smaller l2 norm!') |
| FROM svm_model AS noreg, svm_model_small_norm2 AS l2; |
| |
| |
| -- l1 regularization makes sparse models |
| SELECT svm_classification( |
| 'svm_normalized', |
| 'svm_model_very_sparse', |
| 'label', |
| 'ind', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| NULL, -- grouping_col |
| 'init_stepsize=0.03, decay_factor=1, max_iter=5, tolerance=0, lambda=1, norm=L1' |
| ); |
| \x on |
| SELECT * FROM svm_model_very_sparse; |
| \x off |
| SELECT |
| assert( |
| count(*) > 0, |
| 'The model is supposed to be sparse with reg=1') |
| FROM |
| ( |
| SELECT unnest(coef) AS w_i FROM svm_model_very_sparse |
| ) subq |
| WHERE w_i != 0; |
| |
| -- predicting |
| SELECT svm_predict('svm_model','svm_test_normalized', 'id', 'svm_test_predict2'); |
| |
| -- calculating accuracy |
| -- the accuracy is not guaranteed to be high because suitable values for the |
| -- stepsize and decay_factor depend on the actual number of segments |
| SELECT |
| count(*) AS misclassification_count |
| FROM svm_test_predict2 NATURAL JOIN svm_test_normalized |
| WHERE prediction <> label; |
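| |
| -- the same comparison expressed as a fraction of correct predictions; the |
| -- avg-over-CASE form is portable across PostgreSQL/Greenplum versions |
| SELECT avg(CASE WHEN prediction = label THEN 1.0 ELSE 0.0 END) AS accuracy |
| FROM svm_test_predict2 NATURAL JOIN svm_test_normalized; |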
| |
| ---------------------------------------------------------------- |
| -- decay factor non-zero |
| -- learning |
| SELECT svm_classification( |
| 'svm_normalized', |
| 'svm_model_decay_factor_non_zero', |
| 'label', |
| 'ind', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| NULL, -- grouping_col |
| 'init_stepsize=0.03, decay_factor=0.9, max_iter=5, tolerance=0, lambda={0.001}', |
| false -- verbose |
| ); |
| SELECT norm_of_gradient FROM svm_model_decay_factor_non_zero; |
| |
| -- predicting |
| CREATE TABLE svm_test_predict_decay_factor_nonzero AS |
| SELECT |
| svm_test_normalized.id, |
| CASE WHEN array_dot(coef, ind) >= 0 THEN 1 ELSE -1 END AS prediction, |
| label |
| FROM svm_test_normalized, svm_model_decay_factor_non_zero; |
| |
| -- stats for info |
| SELECT count(*) AS misclassification_count |
| FROM svm_test_predict_decay_factor_nonzero |
| WHERE prediction <> label; |
| |
| |
| ----------------------------------------------------------------- |
| -- labels that are not just 1,-1 |
| CREATE TABLE svm_normalized_fancy_label AS |
| SELECT |
| id, |
| array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind, |
| CASE when label = 1 THEN 'YES' |
| ELSE 'NO' |
| END AS label, |
| (id % 4) AS gid |
| FROM svm_train_data, |
| ( |
| SELECT ARRAY[avg(ind[1]),avg(ind[2]), |
| avg(ind[3]),avg(ind[4])] AS ind_avg |
| FROM svm_train_data |
| ) AS svm_ind_avg, |
| ( |
| SELECT ARRAY[stddev(ind[1]),stddev(ind[2]), |
| stddev(ind[3]),stddev(ind[4])] AS ind_stddev |
| FROM svm_train_data |
| ) AS svm_ind_stddev |
| ORDER BY random(); |
| INSERT INTO svm_normalized_fancy_label VALUES |
| (1001, ARRAY[NULL,1,1,1,1]::float8[], 'YES', 1001 % 4), |
| (1002, ARRAY[5,1,1,1,1]::float8[], NULL, 1002 % 4), |
| (1003, ARRAY[5,1,NULL,1,1]::float8[], NULL, 1003 % 4); |
| |
| CREATE TABLE svm_test_normalized_fancy_label AS |
| SELECT |
| id, |
| array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind, |
| CASE when label = 1 THEN 'YES' |
| ELSE 'NO' |
| END AS label, |
| (id % 4) as gid |
| FROM svm_test_data, |
| ( |
| SELECT ARRAY[avg(ind[1]),avg(ind[2]), |
| avg(ind[3]),avg(ind[4])] AS ind_avg |
| FROM svm_test_data |
| ) AS svm_test_ind_avg, |
| ( |
| SELECT ARRAY[stddev(ind[1]),stddev(ind[2]), |
| stddev(ind[3]),stddev(ind[4])] AS ind_stddev |
| FROM svm_test_data |
| ) AS svm_test_ind_stddev; |
| INSERT INTO svm_test_normalized_fancy_label VALUES |
| (1001, ARRAY[NULL,1,1,1,1]::float8[], 'YES', 1001 % 4); |
| |
| -- training |
| SELECT svm_classification( |
| 'svm_normalized_fancy_label', |
| 'svm_model_fancy_label', |
| 'label', |
| 'ind', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| 'gid', -- grouping_col |
| 'init_stepsize=0.03, decay_factor=0.9, max_iter=5, tolerance=0, lambda=0.001', |
| TRUE -- verbose |
| ); |
| \x on |
| SELECT * FROM svm_model_fancy_label; |
| SELECT * FROM svm_model_fancy_label_summary; |
| \x off |
| SELECT assert(count(*)=4, '4 groups should exist') FROM svm_model_fancy_label; |
| -- SELECT assert(total_rows_skipped=3, 'total_rows_skipped is wrong') |
| -- FROM svm_model_fancy_label_summary; |
| |
| SELECT svm_predict('svm_model_fancy_label', 'svm_test_normalized_fancy_label', 'id', 'svm_test_fancy_label'); |
| SELECT o.id, label, prediction, o.gid FROM svm_test_fancy_label p, svm_test_normalized_fancy_label o WHERE o.id = p.id; |
| |
| -- calculating accuracy |
| -- the accuracy is not guaranteed to be high because the stepsize & decay_factor |
| -- depend on the actual number of segments |
| -- SELECT |
| -- count(*) AS misclassification_count |
| -- FROM svm_test_predict NATURAL JOIN svm_test_normalized_fancy_label |
| -- WHERE prediction <> label; |
| |
| -- tests for the dependent variable being an expression |
| SELECT svm_classification( |
| 'svm_normalized', |
| 'svm_model_expression', |
| 'label>(ind[2]+ind[4])', |
| 'ARRAY[ind[1],ind[3],ind[5]]', |
| NULL, -- kernel_func |
| NULL, -- kernel_params |
| NULL, -- grouping_col |
| 'init_stepsize=0.03, decay_factor=0.9, max_iter=5, tolerance=0, lambda=0.001', |
| false -- verbose |
| ); |
| \x on |
| SELECT * FROM svm_model_expression; |
| SELECT * FROM svm_model_expression_summary; |
| \x off |
| |
| |
| SELECT svm_one_class( |
| 'svm_normalized', |
| 'svm_model_expression1', |
| 'ind', |
| 'gaussian' |
| ); |
| \x on |
| SELECT * FROM svm_model_expression1; |
| SELECT * FROM svm_model_expression1_summary; |
| \x off |
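| |
| -- exercise prediction with the one-class model as well, mirroring the |
| -- svm_predict calls used for the other models in this file (the output |
| -- table name is arbitrary) |
| SELECT svm_predict('svm_model_expression1', 'svm_test_normalized', 'id', 'svm_test_one_class'); |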
| |
| CREATE TABLE abalone_train_small_tmp ( |
| sex TEXT, |
| id SERIAL NOT NULL, |
| length DOUBLE PRECISION, |
| diameter DOUBLE PRECISION, |
| height DOUBLE PRECISION, |
| whole DOUBLE PRECISION, |
| shucked DOUBLE PRECISION, |
| viscera DOUBLE PRECISION, |
| shell DOUBLE PRECISION, |
| rings INTEGER); |
| |
| INSERT INTO abalone_train_small_tmp(id,sex,length,diameter,height,whole,shucked,viscera,shell,rings) VALUES |
| (1040,'F',0.66,0.475,0.18,1.3695,0.641,0.294,0.335,6), |
| (3160,'F',0.34,0.255,0.085,0.204,0.097,0.021,0.05,6), |
| (3984,'F',0.585,0.45,0.125,0.874,0.3545,0.2075,0.225,6), |
| (861,'F',0.595,0.475,0.16,1.1405,0.547,0.231,0.271,6), |
| (932,'F',0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6), |
| (1585,'F',0.515,0.375,0.11,0.6065,0.3005,0.131,0.15,6), |
| (3187,'F',0.47,0.36,0.11,0.4965,0.237,0.127,0.13,6), |
| (3202,'F',0.385,0.3,0.1,0.2725,0.1115,0.057,0.08,6), |
| (949,'F',0.475,0.36,0.12,0.5915,0.3245,0.11,0.127,6), |
| (2582,'F',0.53,0.42,0.17,0.828,0.41,0.208,0.1505,6), |
| (2551,'I',0.28,0.22,0.08,0.1315,0.066,0.024,0.03,5), |
| (1246,'I',0.385,0.28,0.09,0.228,0.1025,0.042,0.0655,5), |
| (819,'I',0.35,0.25,0.07,0.18,0.0655,0.048,0.054,6), |
| (297,'I',0.275,0.205,0.075,0.1105,0.045,0.0285,0.035,6), |
| (3630,'I',0.27,0.205,0.05,0.084,0.03,0.0185,0.029,6), |
| (2196,'I',0.26,0.215,0.08,0.099,0.037,0.0255,0.045,5), |
| (2343,'I',0.255,0.185,0.07,0.075,0.028,0.018,0.025,6), |
| (49,'I',0.325,0.245,0.07,0.161,0.0755,0.0255,0.045,6), |
| (2185,'I',0.32,0.235,0.08,0.1485,0.064,0.031,0.045,6), |
| (2154,'I',0.28,0.2,0.075,0.1225,0.0545,0.0115,0.035,5), |
| (1996,'I',0.32,0.24,0.07,0.133,0.0585,0.0255,0.041,6), |
| (126,'I',0.27,0.195,0.06,0.073,0.0285,0.0235,0.03,5), |
| (1227,'I',0.35,0.27,0.075,0.215,0.1,0.036,0.065,6), |
| (3969,'I',0.375,0.29,0.095,0.2875,0.123,0.0605,0.08,6), |
| (2505,'I',0.31,0.24,0.105,0.2885,0.118,0.065,0.083,6), |
| (2039,'I',0.28,0.215,0.08,0.132,0.072,0.022,0.033,5), |
| (829,'I',0.41,0.325,0.1,0.394,0.208,0.0655,0.106,6), |
| (3197,'I',0.325,0.245,0.075,0.1495,0.0605,0.033,0.045,5), |
| (1447,'I',0.44,0.34,0.105,0.369,0.164,0.08,0.1015,5), |
| (2821,'I',0.375,0.285,0.09,0.2545,0.119,0.0595,0.0675,6), |
| (1828,'I',0.34,0.26,0.085,0.1885,0.0815,0.0335,0.06,6), |
| (2002,'I',0.36,0.27,0.085,0.2185,0.1065,0.038,0.062,6), |
| (785,'I',0.215,0.155,0.06,0.0525,0.021,0.0165,0.015,5), |
| (2199,'I',0.27,0.19,0.08,0.081,0.0265,0.0195,0.03,6), |
| (3527,'I',0.335,0.26,0.085,0.192,0.097,0.03,0.054,6), |
| (466,'I',0.175,0.125,0.05,0.0235,0.008,0.0035,0.008,5), |
| (425,'I',0.26,0.2,0.07,0.092,0.037,0.02,0.03,6), |
| (1825,'I',0.185,0.135,0.04,0.027,0.0105,0.0055,0.009,5), |
| (3815,'I',0.38,0.275,0.095,0.2425,0.106,0.0485,0.21,6), |
| (2503,'I',0.285,0.21,0.07,0.109,0.044,0.0265,0.033,5), |
| (3998,'I',0.36,0.27,0.09,0.2075,0.098,0.039,0.062,6), |
| (333,'I',0.3,0.22,0.08,0.121,0.0475,0.042,0.035,5), |
| (1837,'I',0.415,0.31,0.09,0.2815,0.1245,0.0615,0.085,6), |
| (2813,'I',0.24,0.17,0.05,0.0545,0.0205,0.016,0.0155,5), |
| (930,'I',0.44,0.345,0.13,0.4495,0.209,0.0835,0.134,6), |
| (1436,'I',0.385,0.3,0.09,0.247,0.1225,0.044,0.0675,5), |
| (3972,'I',0.4,0.295,0.095,0.252,0.1105,0.0575,0.066,6), |
| (1433,'I',0.365,0.255,0.08,0.1985,0.0785,0.0345,0.053,5), |
| (1252,'I',0.405,0.285,0.09,0.2645,0.1265,0.0505,0.075,6), |
| (3439,'I',0.43,0.335,0.105,0.378,0.188,0.0785,0.09,6), |
| (1250,'I',0.395,0.27,0.1,0.2985,0.1445,0.061,0.082,5), |
| (2865,'I',0.31,0.23,0.07,0.1245,0.0505,0.0265,0.038,6), |
| (3411,'I',0.415,0.31,0.105,0.3595,0.167,0.083,0.0915,6), |
| (1539,'I',0.355,0.27,0.075,0.1775,0.079,0.0315,0.054,6), |
| (1990,'I',0.28,0.21,0.075,0.1195,0.053,0.0265,0.03,6), |
| (1771,'I',0.455,0.335,0.105,0.422,0.229,0.0865,0.1,6), |
| (2291,'I',0.325,0.27,0.1,0.185,0.08,0.0435,0.065,6), |
| (3381,'I',0.19,0.13,0.045,0.0265,0.009,0.005,0.009,5), |
| (1545,'I',0.37,0.27,0.095,0.2175,0.097,0.046,0.065,6), |
| (652,'I',0.335,0.245,0.09,0.1665,0.0595,0.04,0.06,6), |
| (3434,'I',0.365,0.27,0.105,0.2155,0.0915,0.0475,0.063,6), |
| (2004,'I',0.375,0.28,0.08,0.226,0.105,0.047,0.065,6), |
| (2000,'I',0.35,0.25,0.07,0.1605,0.0715,0.0335,0.046,6), |
| (3946,'I',0.235,0.175,0.065,0.0615,0.0205,0.02,0.019,6), |
| (177,'I',0.315,0.21,0.06,0.125,0.06,0.0375,0.035,5), |
| (920,'I',0.41,0.31,0.09,0.3335,0.1635,0.061,0.091,6), |
| (3437,'I',0.38,0.275,0.095,0.2505,0.0945,0.0655,0.075,6), |
| (2630,'I',0.33,0.24,0.075,0.163,0.0745,0.033,0.048,6), |
| (1092,'I',0.45,0.33,0.11,0.3685,0.16,0.0885,0.102,6), |
| (3476,'I',0.4,0.315,0.085,0.2675,0.116,0.0585,0.0765,6), |
| (3526,'I',0.33,0.23,0.085,0.1695,0.079,0.026,0.0505,6), |
| (1534,'I',0.295,0.215,0.07,0.121,0.047,0.0155,0.0405,6), |
| (921,'I',0.415,0.33,0.09,0.3595,0.17,0.081,0.09,6), |
| (2206,'I',0.275,0.22,0.08,0.1365,0.0565,0.0285,0.042,6), |
| (1218,'I',0.315,0.23,0.08,0.1375,0.0545,0.031,0.0445,5), |
| (1998,'I',0.335,0.25,0.08,0.1695,0.0695,0.044,0.0495,6), |
| (2455,'I',0.275,0.2,0.065,0.092,0.0385,0.0235,0.027,5), |
| (2548,'I',0.23,0.18,0.05,0.064,0.0215,0.0135,0.02,5), |
| (3996,'I',0.245,0.175,0.055,0.0785,0.04,0.018,0.02,5), |
| (3408,'I',0.35,0.265,0.08,0.192,0.081,0.0465,0.053,6), |
| (3907,'M',0.245,0.18,0.065,0.0635,0.0245,0.0135,0.02,4), |
| (3850,'M',0.385,0.3,0.115,0.3435,0.1645,0.085,0.1025,6), |
| (124,'M',0.37,0.265,0.075,0.214,0.09,0.051,0.07,6), |
| (2583,'M',0.53,0.41,0.14,0.681,0.3095,0.1415,0.1835,6), |
| (526,'M',0.175,0.125,0.04,0.024,0.0095,0.006,0.005,4), |
| (2184,'M',0.495,0.4,0.155,0.8085,0.2345,0.1155,0.35,6), |
| (2132,'M',0.32,0.24,0.08,0.18,0.08,0.0385,0.055,6), |
| (651,'M',0.255,0.18,0.065,0.079,0.034,0.014,0.025,5), |
| (612,'M',0.195,0.145,0.05,0.032,0.01,0.008,0.012,4), |
| (958,'M',0.5,0.39,0.135,0.6595,0.3145,0.1535,0.1565,6), |
| (3174,'M',0.35,0.265,0.09,0.2265,0.0995,0.0575,0.065,6), |
| (265,'M',0.27,0.2,0.08,0.1205,0.0465,0.028,0.04,6), |
| (519,'M',0.325,0.23,0.09,0.147,0.06,0.034,0.045,4), |
| (2382,'M',0.155,0.115,0.025,0.024,0.009,0.005,0.0075,5), |
| (698,'M',0.28,0.205,0.1,0.1165,0.0545,0.0285,0.03,5), |
| (2381,'M',0.175,0.135,0.04,0.0305,0.011,0.0075,0.01,5), |
| (516,'M',0.27,0.195,0.08,0.1,0.0385,0.0195,0.03,6), |
| (831,'M',0.415,0.305,0.1,0.325,0.156,0.0505,0.091,6), |
| (3359,'M',0.285,0.215,0.075,0.106,0.0415,0.023,0.035,5); |
| |
| CREATE TABLE abalone_train_small AS ( |
| SELECT * FROM abalone_train_small_tmp |
| ); |
| |
| -- create epsilon input table |
| CREATE TABLE abalone_eps ( |
| sex TEXT, |
| epsilon DOUBLE PRECISION); |
| INSERT INTO abalone_eps(sex, epsilon) VALUES |
| ('I', 0.2), |
| ('M', 0.05); |
| |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm5', |
| 'label', |
| 'ind', |
| NULL,NULL,NULL, |
| 'init_stepsize=0.01, max_iter=20, lambda=0.000002'); |
| SELECT svm_predict('m5','svm_test_data', 'id', 'svm_test_5'); |
| -- accuracy without cv |
| SELECT |
| count(*) AS misclassification_count |
| FROM svm_test_5 NATURAL JOIN svm_test_data |
| WHERE prediction <> label; |
| |
| -- SVM with kernels ----------------------------------------------------------- |
| -- verify gaussian kernel mapping dimensions |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm6', |
| 'label', |
| 'ind', |
| 'gaussian', |
| 'n_components=3, fit_intercept=false', |
| NULL, |
| 'max_iter=2'); |
| SELECT svm_predict('m6','svm_test_data', 'id', 'svm_test_6'); |
| SELECT |
| assert( |
| array_upper(coef, 1) = 3, |
| 'The dimension of the coefficients must be equal to n_components (3)!') |
| FROM m6; |
| |
| -- verify gaussian kernel with grouping |
| -- verify partial string support in kernel specification |
| SELECT svm_regression( |
| 'abalone_train_small', |
| 'svr_mdl_m', |
| 'rings', |
| 'ARRAY[1,diameter,shell,shucked,length]', |
| 'gau', |
| 'n_components=10', |
| 'sex', |
| 'max_iter=2, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.05', |
| false); |
| SELECT svm_predict('svr_mdl_m','abalone_train_small', 'id', 'svm_test_mdl_m'); |
| SELECT |
| assert( |
| array_upper(coef, 1) = 10, |
| 'The dimension of the coefficients must be equal to n_components (10)!') |
| FROM svr_mdl_m; |
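| |
| -- (optional) sanity check on grouping, assuming the model table holds one |
| -- row per value of the grouping column (sex takes the values F, I, M here) |
| SELECT assert(count(*) = 3, 'expected one model row per sex group') |
| FROM svr_mdl_m; |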
| |
| -- verify gaussian kernel with cross validation |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm7', |
| 'label', |
| 'ind', |
| 'gaussian', |
| 'n_components=3, fit_intercept=true', |
| NULL, |
| 'init_stepsize=[0.01, 0.1], max_iter=2, n_folds=3, lambda=[0.01, 0.1, 0.5], validation_result=m7_cv'); |
| |
| SELECT * FROM m7_cv; |
| |
| SELECT svm_predict('m7','svm_test_data', 'id', 'svm_test_7'); |
| SELECT |
| assert( |
| array_upper(coef, 1) = 4, |
| 'The dimension of the coefficients must be equal to n_components + 1 (4)!') |
| FROM m7; |
| |
| -- verify gaussian kernel with out-of-memory support |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm8', |
| 'label', |
| 'ind', |
| 'gaussian', |
| 'n_components=3, fit_in_memory=False', |
| NULL, |
| 'max_iter=2, n_folds=3, lambda=[0.01, 0.1, 0.5]'); |
| |
| SELECT |
| assert( |
| array_upper(coef, 1) = 3, |
| 'The dimension of the coefficients must be equal to n_components (3)!') |
| FROM m8; |
| |
| CREATE TABLE kernel_data ( |
| index bigint, |
| x1 double precision, |
| x2 double precision, |
| y double precision |
| ); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (0, 0.400000000000000022, -0.699999999999999956, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (1, -1.5, -1, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (2, -1.39999999999999991, -0.900000000000000022, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (3, -1.30000000000000004, -1.19999999999999996, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (4, -1.10000000000000009, -0.200000000000000011, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (5, -1.19999999999999996, -0.400000000000000022, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (6, -0.5, 1.19999999999999996, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (7, -1.5, 2.10000000000000009, 0); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (8, 1, 1, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (9, 1.30000000000000004, 0.800000000000000044, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (10, 1.19999999999999996, 0.5, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (11, 0.200000000000000011, -2, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (12, 0.5, -2.39999999999999991, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (13, 0.200000000000000011, -2.29999999999999982, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (14, 0, -2.70000000000000018, 1); |
| INSERT INTO kernel_data (index, x1, x2, y) VALUES (15, 1.30000000000000004, 2.10000000000000009, 1); |
| |
| -- verify poly kernel mapping dimensions |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'poly_mapping', |
| 'label', |
| 'ind', |
| 'poly', |
| 'n_components=3, fit_intercept=true', |
| NULL, |
| 'max_iter=2'); |
| SELECT svm_predict('poly_mapping','svm_test_data', 'id', 'svm_test_poly_mapping'); |
| SELECT |
| assert( |
| array_upper(coef, 1) = 4, |
| 'The dimension of the coefficients must be equal to n_components + 1 (4)!') |
| FROM poly_mapping; |
| |
| -- verify poly kernel with grouping |
| -- verify partial string support in kernel specification |
| SELECT svm_regression( |
| 'abalone_train_small', |
| 'svr_mdl_poly', |
| 'rings', |
| 'ARRAY[1,diameter,shell,shucked,length]', |
| 'po', |
| 'degree=2, n_components=10, fit_intercept=true', |
| 'sex', |
| 'max_iter=2, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.05', |
| false); |
| SELECT svm_predict('svr_mdl_poly','abalone_train_small', 'id', 'svm_test_poly'); |
| SELECT |
| assert( |
| array_upper(coef, 1) = 11, |
| 'The dimension of the coefficients must be equal to n_components + 1 (11)!') |
| FROM svr_mdl_poly; |
| |
| SELECT svm_classification( |
| 'kernel_data', |
| 'm10', |
| 'y', |
| 'array[x1, x2]', |
| 'gaussian', |
| 'gamma=1, n_components=20, random_state=2', |
| NULL, |
| 'init_stepsize=1, max_iter=10'); |
| SELECT * FROM m10; |
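| |
| -- (optional) in-sample check of the gaussian-kernel model on kernel_data; |
| -- the output table name is arbitrary |
| SELECT svm_predict('m10', 'kernel_data', 'index', 'm10_predict'); |
| SELECT count(*) AS misclassification_count |
| FROM m10_predict JOIN kernel_data USING (index) |
| WHERE prediction <> y; |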
| |
| -- Test for class weight |
| CREATE TABLE svm_unbalanced ( |
| index bigint, |
| x1 double precision, |
| x2 double precision, |
| y bigint, |
| y_text text |
| ); |
| COPY svm_unbalanced (index, x1, x2, y, y_text) FROM stdin delimiter '|'; |
| 0|2.43651804549486251|-0.917634620475113127|0|zero |
| 1|-0.792257628395183544|-1.60945293323425576|0|zero |
| 2|1.29811144398701783|-3.45230804532042423|0|zero |
| 3|2.61721764632472009|-1.14181035134265407|0|zero |
| 4|0.478558644085647855|-0.374055563216115106|0|zero |
| 5|2.19316190556746093|-3.09021106424648107|0|zero |
| 6|-0.483625806020261229|-0.576081532002623464|0|zero |
| 7|1.70065416350315601|-1.64983690097104629|0|zero |
| 8|-0.258642311325653629|-1.31678762688205753|0|zero |
| 9|0.0633206200733892471|0.87422282057373335|0|zero |
| 10|-1.65092876581938186|1.7170855647594212|0|zero |
| 11|1.35238608088919321|0.753741508352802292|0|zero |
| 12|1.35128392389661767|-1.02559178876149959|0|zero |
| 13|-0.184335338277972272|-1.40365415138860317|0|zero |
| 14|-0.40183211943902386|0.795533200107279015|0|zero |
| 15|-1.03749112758796347|-0.595130290283966024|0|zero |
| 16|-1.03075905017939906|-1.26780846224807942|0|zero |
| 17|-1.00686919625522853|-0.0189968983783520423|0|zero |
| 18|-1.67596552295291668|0.351623546725638225|0|zero |
| 19|2.48970326566480571|1.11306624086600348|0|zero |
| 20|-0.287753328542422415|-1.3314434461272544|0|zero |
| 21|-1.12073744062625646|2.53868190154161999|0|zero |
| 22|0.0762116321640434469|-0.955493469854030053|0|zero |
| 23|0.286373227001199049|3.15038270471826332|0|zero |
| 24|0.180238428722443722|0.925804664561128865|0|zero |
| 25|0.450255479933741265|-0.528374769740277972|0|zero |
| 26|-1.71377729703321036|-0.524014083619316229|0|zero |
| 27|-0.313341350062167179|0.879934786773296507|0|zero |
| 28|1.25847512081175728|1.39665312195533597|0|zero |
| 29|0.428380987881388176|1.32771174640609213|0|zero |
| 30|-1.1315969114949791|1.87930223284993181|0|zero |
| 31|0.769394730627013246|-0.447139252654073505|0|zero |
| 32|0.73277721980624555|-0.113357569531583588|0|zero |
| 33|1.69744408117714052|2.27972522463329819|0|zero |
| 34|3.27836310979974233|-2.09474450323220651|0|zero |
| 35|-2.16617070814438417|-0.756698794419676801|0|zero |
| 36|0.240055604171745707|1.31425338167433736|0|zero |
| 37|0.473452420862407852|-3.03330182373600454|0|zero |
| 38|-0.459306018942557737|1.24196196391086922|0|zero |
| 39|0.345142103046575111|1.14301677046803718|0|zero |
| 40|-0.333492213915538904|-0.301137103394996164|0|zero |
| 41|0.279842086482426478|0.615077470812384508|0|zero |
| 42|0.297449580190154605|0.178512968711188214|0|zero |
| 43|-1.00599342943354575|0.56634567948137915|0|zero |
| 44|0.182731906487155399|1.6942258618678796|0|zero |
| 45|1.7983768198522605|0.277734626225915771|0|zero |
| 46|-0.562927425135171244|-0.958095611181333573|0|zero |
| 47|0.635241531096169321|0.116010102522839123|0|zero |
| 48|-0.515780513356613346|0.065395285251370408|0|zero |
| 49|-0.930001265922193898|1.04704805110832844|0|zero |
| 50|-0.670692847178997353|1.8367615572082483|0|zero |
| 51|0.605237462686200045|0.890367784855600419|0|zero |
| 52|-1.64236776861156275|0.254073649588002159|0|zero |
| 53|1.11083467664441216|-1.43055090271190188|0|zero |
| 54|-0.399327759005433103|0.0489218200400378389|0|zero |
| 55|-2.05967598037013344|0.472739088063437674|0|zero |
| 56|1.2692409713775501|-1.28927391124797941|0|zero |
| 57|0.525818967996161013|-1.96842511685614774|0|zero |
| 58|-0.0580432638990766719|-2.42365853205494197|0|zero |
| 59|1.682126562353496|0.613350806905241686|0|zero |
| 60|-0.0369254338136675297|-1.16274242875373934|0|zero |
| 61|1.91063389523816496|2.95065262388210225|0|zero |
| 62|-2.78697279667012809|1.85424604567923046|0|zero |
| 63|2.44147612972335981|0.507017544861713687|0|zero |
| 64|-1.79890204850277913|1.29501797631603233|0|zero |
| 65|-0.271380453117225695|-0.905880941689885866|0|zero |
| 66|-1.84508720350044264|0.825806243964323006|0|zero |
| 67|1.18921029887902163|-0.935296094519687427|0|zero |
| 68|0.78086450561005627|-1.71651208443471415|0|zero |
| 69|1.20279154780701703|0.0698509476362183107|0|zero |
| 70|-0.279854657861023148|-0.152618808793717808|0|zero |
| 71|1.30332923550880198|1.12561745979751215|0|zero |
| 72|0.794197986529063815|0.206551814996079108|0|zero |
| 73|0.116731691869058879|0.927570392997786763|0|zero |
| 74|0.348741838768106827|1.02382711029672779|0|zero |
| 75|-0.465175160277089994|-3.65225664616070844|0|zero |
| 76|1.55823690278912119|3.28046947046138637|0|zero |
| 77|0.662046665352873154|-0.150232849925249656|0|zero |
| 78|-0.204667115844049508|-0.178581281662214819|0|zero |
| 79|0.0261141124500068982|-1.68302809312033252|0|zero |
| 80|-0.775641686880341852|-1.49554024147539444|0|zero |
| 81|0.373198742081655654|-0.444961728556294012|0|zero |
| 82|0.742816985966940679|-0.26205473961375142|0|zero |
| 83|1.47950278173186289|0.320300852003162662|0|zero |
| 84|3.28604959345460035|-2.8445413843366385|0|zero |
| 85|-0.970375032382362002|1.35223033747306642|0|zero |
| 86|3.79248856020959701|-0.37295216657319008|0|zero |
| 87|0.0655034897675836614|-0.339471363770407764|0|zero |
| 88|1.9971856688813876|-0.430961795214028331|0|zero |
| 89|1.02010475981715665|-0.479702398348006764|0|zero |
| 90|-1.9088381328689914|0.470321580695148234|0|zero |
| 91|0.754777220152989092|1.93983882379839256|0|zero |
| 92|-0.165670539625974472|-0.926043095568541363|0|zero |
| 93|0.844141644928539492|0.361105638356598369|0|zero |
| 94|0.420997615683958493|-0.109669055620916639|0|zero |
| 95|1.7405078549906543|0.554239074563585565|0|zero |
| 96|2.85698806251147186|1.66658504784075689|0|zero |
| 97|0.988574694150315292|-2.44115751092438593|0|zero |
| 98|0.903478920443443689|0.630423305470589446|0|zero |
| 99|1.2164275092053336|1.56666314206088808|0|zero |
| 100|1.79956090410553671|2.41200280922520394|1|one |
| 101|1.71884728449045499|2.97743903750451722|1|one |
| 102|1.33402416674137592|1.1196557198006083|1|one |
| 103|1.1746393670879498|1.55472220791847571|1|one |
| 104|1.4404423007201359|2.97803945185182073|1|one |
| 105|1.83675025096090794|1.32866210531128193|1|one |
| 106|2.55719148838989607|1.7067380305892037|1|one |
| 107|1.38157331172930142|2.43791946382464975|1|one |
| 108|2.31168108828901619|1.7825216585223862|1|one |
| 109|2.70377000012061419|2.06455078985536256|1|one |
| \. |
| |
| DROP TABLE IF EXISTS svm_out, svm_out_summary; |
| SELECT svm_classification( |
| 'svm_unbalanced', |
| 'svm_out', |
| 'y', |
| 'ARRAY[1, x1, x2]', |
| 'linear', |
| NULL, |
| NULL, |
| 'max_iter=1000, init_stepsize=0.1, class_weight=balanced' |
| ); |
| |
| DROP TABLE IF EXISTS svm_predict_out; |
| SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out'); |
| |
| -- check that the prediction accuracy for the minority class is reasonably |
| -- high; without the class weight it can drop to around 50%. |
| SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced class is too low') |
| FROM svm_unbalanced JOIN svm_predict_out |
| using (index) |
| WHERE y = prediction and y = 1; |
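| |
| -- (optional) per-class confusion breakdown, which makes the effect of |
| -- class_weight=balanced easier to inspect; uses only tables created above |
| SELECT y, prediction, count(*) AS n |
| FROM svm_unbalanced JOIN svm_predict_out USING (index) |
| GROUP BY y, prediction |
| ORDER BY y, prediction; |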
| |
| -- Test case with class_weight specified as a mapping. svm_unbalanced has ten |
| -- times as many examples of class 0 as of class 1, so a mapping of |
| -- {1:10, 0:1} should behave the same as class_weight=balanced. |
| DROP TABLE IF EXISTS svm_out, svm_out_summary; |
| SELECT svm_classification( |
| 'svm_unbalanced', |
| 'svm_out', |
| 'y', |
| 'ARRAY[1, x1, x2]', |
| 'linear', |
| NULL, |
| NULL, |
| 'max_iter=1000, init_stepsize=0.1, class_weight={1:10, 0:1}' |
| ); |
| |
| DROP TABLE IF EXISTS svm_predict_out; |
| SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out'); |
| -- check that the prediction accuracy for the minority class is reasonably |
| -- high; without the class weight it can drop to around 50%. |
| SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced numeric class with mapping class_weight is too low') |
| FROM svm_unbalanced JOIN svm_predict_out |
| using (index) |
| WHERE y = prediction and y = 1; |
| |
| -- Test case for class_weight with text class values. |
| DROP TABLE IF EXISTS svm_out, svm_out_summary; |
| SELECT svm_classification( |
| 'svm_unbalanced', |
| 'svm_out', |
| 'y_text', |
| 'ARRAY[1, x1, x2]', |
| 'linear', |
| NULL, |
| NULL, |
| 'max_iter=1000, init_stepsize=0.1, class_weight={zero:1, one:10}' |
| ); |
| |
| DROP TABLE IF EXISTS svm_predict_out; |
| SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out'); |
| |
| -- check that the prediction accuracy for the minority class is reasonably |
| -- high; without the class weight it can drop to around 50%. |
| SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced text class with mapping class_weight is too low') |
| FROM svm_unbalanced JOIN svm_predict_out |
| using (index) |
| WHERE y_text = prediction and y_text = 'one'; |
| |
| -- Cross validation tests |
| SELECT svm_one_class( |
| 'svm_normalized', |
| 'svm_model_expression2', |
| 'ind', |
| 'gaussian', |
| NULL, |
| NULL, |
| 'init_stepsize=0.01, max_iter=3, lambda=[0.0002, 0.2], ' |
| 'n_folds=3, epsilon = [0.003, 0.2]' |
| ); |
| \x on |
| SELECT * FROM svm_model_expression2; |
| SELECT * FROM svm_model_expression2_summary; |
| \x off |
| |
| SELECT svm_predict('svm_model_expression2', 'svm_test_normalized', 'id', 'svm_test_model_expression2'); |
| SELECT svm_regression( |
| 'svr_train_data', |
| 'm1', |
| 'label', |
| 'ind', |
| 'poly', |
| NULL, |
| NULL, |
| 'init_stepsize=0.01, max_iter=3, lambda=[0.0002, 0.2], ' |
| 'n_folds=3, epsilon = [0.003, 0.2]', |
| true); |
| SELECT svm_predict('m1','svm_test_data', 'id', 'svm_test_8'); |
| |
| SELECT svm_regression( |
| 'svr_train_data', |
| 'm2', |
| 'label', |
| 'ind', |
| NULL,NULL,NULL, |
| 'init_stepsize=0.01, max_iter=2, lambda=[0.0002, 0.2], n_folds=3', |
| false); |
| -- check which lambda is selected |
| SELECT reg_params FROM m2_summary; |
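| |
| -- (optional) the full summary shows the remaining cross-validation settings, |
| -- in the same style as the other <model>_summary inspections in this file |
| \x on |
| SELECT * FROM m2_summary; |
| \x off |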
| |
| -- epsilon values are ignored for classification, so the validation table |
| -- only contains init_stepsize and lambda |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm3', |
| 'label', |
| 'ind', |
| NULL,NULL,NULL, |
| 'init_stepsize=[0.01, 1], max_iter=3, lambda=[20, 0.0002, 0.02], ' |
| 'n_folds=3, epsilon=[0.1, 1], validation_result=val_res'); |
| SELECT * FROM val_res; |
| |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm4', |
| 'label', |
| 'ind', |
| NULL,NULL,NULL, |
| 'init_stepsize=0.01, max_iter=20, lambda=[20, 0.0002, 0.02], ' |
| 'n_folds=3, validation_result=val_res2'); |
| SELECT * FROM val_res2; |
| -- check which lambda is selected |
| SELECT reg_params FROM m1_summary; |
| SELECT svm_predict('m1','svm_test_data', 'id', 'svm_test_reg_params'); |
| |
| -- verify poly kernel with cross validation |
| SELECT svm_classification( |
| 'svm_train_data', |
| 'm9', |
| 'label', |
| 'ind', |
| 'poly', |
| 'n_components=3', |
| NULL, |
| 'max_iter=2, n_folds=3, lambda=[0.01, 0.1, 0.5]'); |
| SELECT svm_predict('m9','svm_test_data', 'id', 'svm_test_9'); |
| SELECT |
| assert( |
| array_upper(coef, 1) = 3, |
| 'The dimension of the coefficients must be equal to n_components (3)!') |
| FROM m9; |