/* -----------------------------------------------------------------------------
* Test Linear Support Vector Machine
* -------------------------------------------------------------------------- */
CREATE OR REPLACE FUNCTION __svm_target_cl_func(ind float8[])
RETURNS float8 AS $$
BEGIN
IF (ind[1] > 0 AND ind[2] < 0) THEN RETURN 1; END IF;
RETURN -1;
END
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION __svr_target_cl_func(ind float8[])
RETURNS float8 AS $$
BEGIN
RETURN 1*ind[1] + 2*ind[2];
END
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION __svm_random_ind(d INT)
RETURNS float8[] AS $$
DECLARE
ret float8[];
BEGIN
FOR i IN 1..(d-1) LOOP
ret[i] = RANDOM() * 40 - 20;
END LOOP;
IF (RANDOM() > 0.5) THEN
ret[d] = 10;
ELSE
ret[d] = -10;
END IF;
RETURN ret;
END
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION svm_generate_cls_data(
output_table text, num int, dim int)
RETURNS VOID AS $$
DECLARE
temp_table text;
BEGIN
temp_table := 'madlib_temp_' || output_table;
EXECUTE '
CREATE TABLE ' || temp_table || ' AS
SELECT
subq.val AS id,
__svm_random_ind(' || dim || ') AS ind
FROM
(SELECT generate_series(1, ' || num || ') AS val) subq';
EXECUTE '
CREATE TABLE ' || output_table || ' AS
SELECT id, ind, __svm_target_cl_func(ind) AS label
FROM ' || temp_table;
END
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION svr_generate_cls_data(
output_table text, num int, dim int)
RETURNS VOID AS $$
DECLARE
temp_table text;
BEGIN
temp_table := 'madlib_temp_' || output_table;
EXECUTE '
CREATE TABLE ' || temp_table || ' AS
SELECT
subq.val AS id,
__svm_random_ind(' || dim || ') AS ind
FROM
(SELECT generate_series(1, ' || num || ') AS val) subq';
EXECUTE '
CREATE TABLE ' || output_table || ' AS
SELECT id, ind, __svr_target_cl_func(ind) AS label
FROM ' || temp_table;
END
$$ LANGUAGE plpgsql;
SELECT svm_generate_cls_data('svm_train_data', 1000, 4);
SELECT svm_generate_cls_data('svm_test_data', 1000, 4);
SELECT svr_generate_cls_data('svr_train_data', 1000, 4);
SELECT svr_generate_cls_data('svr_test_data', 1000, 4);
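-- Optional sanity check (not part of the original tests): the generated
-- classification data should contain rows of both labels.
SELECT
    count(*) AS n_rows,
    sum(CASE WHEN label = 1 THEN 1 ELSE 0 END) AS n_positive,
    sum(CASE WHEN label = -1 THEN 1 ELSE 0 END) AS n_negative
FROM svm_train_data;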
-- check the default values
SELECT svm_regression(
'svr_train_data',
'svr_model',
'label',
'ind');
\x on
SELECT * FROM svr_model;
SELECT * FROM svr_model_summary;
\x off
SELECT
assert(
norm1(coef) < 4,
'optimal coef should be close to [1, 2, 0, 0]!')
FROM svr_model;
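-- For manual inspection only (assumption: not required by the assertion
-- above): show the fitted coefficients and their l1 norm; the coefficients
-- are expected to be close to [1, 2, 0, 0].
SELECT coef, norm1(coef) AS coef_l1_norm FROM svr_model;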
-- check explicit optimization parameters (l2 norm, custom epsilon)
SELECT svm_regression(
'svr_train_data',
'svr_model2',
'label',
'ind',
NULL,
NULL,
NULL,
'init_stepsize=0.01, max_iter=50, lambda=2, norm=l2, epsilon=0.01',
false);
SELECT svm_predict('svr_model2', 'svr_train_data', 'id', 'svr_test_result');
\x on
SELECT * FROM svr_model2;
\x off
SELECT
assert(
avg(subq.err) < 0.1,
'prediction error is too large!')
FROM
(
SELECT
train.id,
abs(train.label - test.prediction) AS err
FROM svr_train_data AS train, svr_test_result AS test
WHERE train.id = test.id
) AS subq;
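-- Informational only (assumption): report the mean absolute error that the
-- assertion above bounds.
SELECT avg(abs(train.label - test.prediction)) AS mean_abs_err
FROM svr_train_data AS train
JOIN svr_test_result AS test ON train.id = test.id;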
-- Example usage for LINEAR classification
SELECT svm_classification(
'svm_train_data',
'lclss',
'label',
'ind',
NULL, -- kernel_func
NULL, -- kernel_params
NULL, --grouping_col
'max_iter=10, tolerance=0' --optim_params
);
SELECT * FROM lclss;
SELECT * FROM lclss_summary;
SELECT svm_predict('lclss', 'svm_test_data', 'id', 'svm_test_predict');
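-- Informational only (assumption): misclassification count of the linear
-- classifier on the test data.
SELECT count(*) AS misclassification_count
FROM svm_test_predict NATURAL JOIN svm_test_data
WHERE prediction <> label;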
-- checking correctness with pre-conditioning
CREATE TABLE svm_normalized AS
SELECT
id,
array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind,
label
FROM svm_train_data,
(
SELECT ARRAY[avg(ind[1]),avg(ind[2]),
avg(ind[3]),avg(ind[4])] AS ind_avg
FROM svm_train_data
) AS svm_ind_avg,
(
SELECT ARRAY[stddev(ind[1]),stddev(ind[2]),
stddev(ind[3]),stddev(ind[4])] AS ind_stddev
FROM svm_train_data
) AS svm_ind_stddev
ORDER BY random();
CREATE TABLE svm_test_normalized AS
SELECT
id,
array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind,
label
FROM svm_test_data,
(
SELECT ARRAY[avg(ind[1]),avg(ind[2]),
avg(ind[3]),avg(ind[4])] AS ind_avg
FROM svm_test_data
) AS svm_test_ind_avg,
(
SELECT ARRAY[stddev(ind[1]),stddev(ind[2]),
stddev(ind[3]),stddev(ind[4])] AS ind_stddev
FROM svm_test_data
) AS svm_test_ind_stddev;
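-- Optional sanity check (not part of the original tests): after
-- pre-conditioning, each non-intercept feature should have roughly zero
-- mean and unit standard deviation.
SELECT avg(ind[1]) AS mean_ind1, stddev(ind[1]) AS stddev_ind1
FROM svm_normalized;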
----------------------------------------------------------------
-- serial
-- learning
SELECT svm_classification(
'svm_normalized',
'svm_model',
'label',
'ind',
NULL, -- kernel_func
NULL, -- kernel_params
NULL, -- grouping_col
'init_stepsize=0.03, decay_factor=1, max_iter=5, tolerance=0, lambda=0',
false -- verbose
);
\x on
SELECT * FROM svm_model;
SELECT * FROM svm_model_summary;
\x off
-- l2
SELECT svm_classification(
'svm_normalized',
'svm_model_small_norm2',
'label',
'ind',
NULL, -- kernel_func
NULL, -- kernel_params
NULL, --grouping_col
'init_stepsize=0.03, decay_factor=1, max_iter=5, tolerance=0, lambda=1'
);
\x on
SELECT * FROM svm_model_small_norm2;
\x off
SELECT
assert(
norm2(l2.coef) < norm2(noreg.coef) OR
(
(norm2(l2.coef)-norm2(noreg.coef))/norm2(noreg.coef) < 0.1 AND
l2.loss < noreg.loss
),
'l2 regularization should produce coef with smaller l2 norm!')
FROM svm_model AS noreg, svm_model_small_norm2 AS l2;
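-- Informational only (assumption): show the two coefficient norms that the
-- assertion above compares.
SELECT norm2(noreg.coef) AS noreg_norm2, norm2(l2.coef) AS l2_norm2
FROM svm_model AS noreg, svm_model_small_norm2 AS l2;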
-- l1 makes sparse models
SELECT svm_classification(
'svm_normalized',
'svm_model_very_sparse',
'label',
'ind',
NULL, -- kernel_func
NULL, -- kernel_params
NULL, --grouping_col
'init_stepsize=0.03, decay_factor=1, max_iter=5, tolerance=0, lambda=1, norm=L1'
);
\x on
SELECT * FROM svm_model_very_sparse;
\x off
SELECT
assert(
count(*) > 0,
'The model is supposed to be sparse with reg=1')
FROM
(
SELECT unnest(coef) AS w_i FROM svm_model_very_sparse
) subq
WHERE w_i != 0;
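-- Informational only (assumption): count how many coefficients the
-- L1-regularized model drives to exactly zero.
SELECT count(*) AS n_zero_coef
FROM (SELECT unnest(coef) AS w_i FROM svm_model_very_sparse) subq
WHERE w_i = 0;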
-- predicting
SELECT svm_predict('svm_model','svm_test_normalized', 'id', 'svm_test_predict2');
-- calculating accuracy
-- the accuracy is not guaranteed to be high because the stepsize & decay_factor
-- depend on the actual number of segments
SELECT
count(*) AS misclassification_count
FROM svm_test_predict2 NATURAL JOIN svm_test_normalized
WHERE prediction <> label;
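-- Informational only (assumption): the same result expressed as an
-- accuracy rate.
SELECT avg(CASE WHEN prediction = label THEN 1.0 ELSE 0.0 END) AS accuracy
FROM svm_test_predict2 NATURAL JOIN svm_test_normalized;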
----------------------------------------------------------------
-- decay factor non-zero
-- learning
SELECT svm_classification(
'svm_normalized',
'svm_model_decay_factor_non_zero',
'label',
'ind',
NULL, -- kernel_func
NULL, -- kernel_params
NULL, --grouping_col
'init_stepsize=0.03, decay_factor=0.9, max_iter=5, tolerance=0, lambda={0.001}',
false -- verbose
);
SELECT norm_of_gradient FROM svm_model_decay_factor_non_zero;
-- predicting
CREATE TABLE svm_test_predict_decay_factor_nonzero AS
SELECT
svm_test_normalized.id,
CASE WHEN array_dot(coef, ind) >= 0 THEN 1 ELSE -1 END AS prediction,
label
FROM svm_test_normalized, svm_model_decay_factor_non_zero;
-- stats for info
SELECT count(*) AS misclassification_count
FROM svm_test_predict_decay_factor_nonzero
WHERE prediction <> label;
-----------------------------------------------------------------
-- labels that are not just 1,-1
CREATE TABLE svm_normalized_fancy_label AS
SELECT
id,
array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind,
CASE when label = 1 THEN 'YES'
ELSE 'NO'
END AS label,
(id % 4) AS gid
FROM svm_train_data,
(
SELECT ARRAY[avg(ind[1]),avg(ind[2]),
avg(ind[3]),avg(ind[4])] AS ind_avg
FROM svm_train_data
) AS svm_ind_avg,
(
SELECT ARRAY[stddev(ind[1]),stddev(ind[2]),
stddev(ind[3]),stddev(ind[4])] AS ind_stddev
FROM svm_train_data
) AS svm_ind_stddev
ORDER BY random();
INSERT INTO svm_normalized_fancy_label VALUES
(1001, ARRAY[NULL,1,1,1,1]::float8[], 'YES', 1001 % 4),
(1002, ARRAY[5,1,1,1,1]::float8[], NULL, 1002 % 4),
(1003, ARRAY[5,1,NULL,1,1]::float8[], NULL, 1003 % 4);
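-- Informational only (assumption): rows with NULLs are expected to be
-- skipped during training; count the NULL-labelled rows for reference.
SELECT count(*) AS n_null_label_rows
FROM svm_normalized_fancy_label
WHERE label IS NULL;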
CREATE TABLE svm_test_normalized_fancy_label AS
SELECT
id,
array_append(array_div(array_sub(ind, ind_avg), ind_stddev), 1::FLOAT8) AS ind,
CASE when label = 1 THEN 'YES'
ELSE 'NO'
END AS label,
(id % 4) as gid
FROM svm_test_data,
(
SELECT ARRAY[avg(ind[1]),avg(ind[2]),
avg(ind[3]),avg(ind[4])] AS ind_avg
FROM svm_test_data
) AS svm_test_ind_avg,
(
SELECT ARRAY[stddev(ind[1]),stddev(ind[2]),
stddev(ind[3]),stddev(ind[4])] AS ind_stddev
FROM svm_test_data
) AS svm_test_ind_stddev;
INSERT INTO svm_test_normalized_fancy_label VALUES
(1001, ARRAY[NULL,1,1,1,1]::float8[], 'YES', 1001 % 4);
-- training
SELECT svm_classification(
'svm_normalized_fancy_label',
'svm_model_fancy_label',
'label',
'ind',
NULL, -- kernel_func
NULL, -- kernel_params
'gid', --grouping_col
'init_stepsize=0.03, decay_factor=0.9, max_iter=5, tolerance=0, lambda=0.001',
TRUE -- verbose
);
\x on
SELECT * FROM svm_model_fancy_label;
SELECT * FROM svm_model_fancy_label_summary;
\x off
SELECT assert(count(*)=4, '4 groups should exist') FROM svm_model_fancy_label;
-- SELECT assert(total_rows_skipped=3, 'total_rows_skipped is wrong')
-- FROM svm_model_fancy_label_summary;
SELECT svm_predict('svm_model_fancy_label', 'svm_test_normalized_fancy_label', 'id', 'svm_test_fancy_label');
SELECT o.id, label, prediction, o.gid FROM svm_test_fancy_label p, svm_test_normalized_fancy_label o WHERE o.id = p.id;
-- calculating accuracy
-- the accuracy is not guaranteed to be high because the stepsize & decay_factor
-- depend on the actual number of segments
-- SELECT
-- count(*) AS misclassification_count
-- FROM svm_test_predict NATURAL JOIN svm_test_normalized_fancy_label
-- WHERE prediction <> label;
-- tests for the dependent variable being an expression
SELECT svm_classification(
'svm_normalized',
'svm_model_expression',
'label>(ind[2]+ind[4])',
'ARRAY[ind[1],ind[3],ind[5]]',
NULL, -- kernel_func
NULL, -- kernel_params
NULL, --grouping_col
'init_stepsize=0.03, decay_factor=0.9, max_iter=5, tolerance=0, lambda=0.001',
false -- verbose
);
\x on
SELECT * FROM svm_model_expression;
SELECT * FROM svm_model_expression_summary;
\x off
SELECT svm_one_class(
'svm_normalized',
'svm_model_expression1',
'ind',
'gaussian'
);
\x on
SELECT * FROM svm_model_expression1;
SELECT * FROM svm_model_expression1_summary;
\x off
CREATE TABLE abalone_train_small_tmp (
sex TEXT,
id SERIAL NOT NULL,
length DOUBLE PRECISION,
diameter DOUBLE PRECISION,
height DOUBLE PRECISION,
whole DOUBLE PRECISION,
shucked DOUBLE PRECISION,
viscera DOUBLE PRECISION,
shell DOUBLE PRECISION,
rings INTEGER);
INSERT INTO abalone_train_small_tmp(id,sex,length,diameter,height,whole,shucked,viscera,shell,rings) VALUES
(1040,'F',0.66,0.475,0.18,1.3695,0.641,0.294,0.335,6),
(3160,'F',0.34,0.255,0.085,0.204,0.097,0.021,0.05,6),
(3984,'F',0.585,0.45,0.125,0.874,0.3545,0.2075,0.225,6),
(861,'F',0.595,0.475,0.16,1.1405,0.547,0.231,0.271,6),
(932,'F',0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6),
(1585,'F',0.515,0.375,0.11,0.6065,0.3005,0.131,0.15,6),
(3187,'F',0.47,0.36,0.11,0.4965,0.237,0.127,0.13,6),
(3202,'F',0.385,0.3,0.1,0.2725,0.1115,0.057,0.08,6),
(949,'F',0.475,0.36,0.12,0.5915,0.3245,0.11,0.127,6),
(2582,'F',0.53,0.42,0.17,0.828,0.41,0.208,0.1505,6),
(2551,'I',0.28,0.22,0.08,0.1315,0.066,0.024,0.03,5),
(1246,'I',0.385,0.28,0.09,0.228,0.1025,0.042,0.0655,5),
(819,'I',0.35,0.25,0.07,0.18,0.0655,0.048,0.054,6),
(297,'I',0.275,0.205,0.075,0.1105,0.045,0.0285,0.035,6),
(3630,'I',0.27,0.205,0.05,0.084,0.03,0.0185,0.029,6),
(2196,'I',0.26,0.215,0.08,0.099,0.037,0.0255,0.045,5),
(2343,'I',0.255,0.185,0.07,0.075,0.028,0.018,0.025,6),
(49,'I',0.325,0.245,0.07,0.161,0.0755,0.0255,0.045,6),
(2185,'I',0.32,0.235,0.08,0.1485,0.064,0.031,0.045,6),
(2154,'I',0.28,0.2,0.075,0.1225,0.0545,0.0115,0.035,5),
(1996,'I',0.32,0.24,0.07,0.133,0.0585,0.0255,0.041,6),
(126,'I',0.27,0.195,0.06,0.073,0.0285,0.0235,0.03,5),
(1227,'I',0.35,0.27,0.075,0.215,0.1,0.036,0.065,6),
(3969,'I',0.375,0.29,0.095,0.2875,0.123,0.0605,0.08,6),
(2505,'I',0.31,0.24,0.105,0.2885,0.118,0.065,0.083,6),
(2039,'I',0.28,0.215,0.08,0.132,0.072,0.022,0.033,5),
(829,'I',0.41,0.325,0.1,0.394,0.208,0.0655,0.106,6),
(3197,'I',0.325,0.245,0.075,0.1495,0.0605,0.033,0.045,5),
(1447,'I',0.44,0.34,0.105,0.369,0.164,0.08,0.1015,5),
(2821,'I',0.375,0.285,0.09,0.2545,0.119,0.0595,0.0675,6),
(1828,'I',0.34,0.26,0.085,0.1885,0.0815,0.0335,0.06,6),
(2002,'I',0.36,0.27,0.085,0.2185,0.1065,0.038,0.062,6),
(785,'I',0.215,0.155,0.06,0.0525,0.021,0.0165,0.015,5),
(2199,'I',0.27,0.19,0.08,0.081,0.0265,0.0195,0.03,6),
(3527,'I',0.335,0.26,0.085,0.192,0.097,0.03,0.054,6),
(466,'I',0.175,0.125,0.05,0.0235,0.008,0.0035,0.008,5),
(425,'I',0.26,0.2,0.07,0.092,0.037,0.02,0.03,6),
(1825,'I',0.185,0.135,0.04,0.027,0.0105,0.0055,0.009,5),
(3815,'I',0.38,0.275,0.095,0.2425,0.106,0.0485,0.21,6),
(2503,'I',0.285,0.21,0.07,0.109,0.044,0.0265,0.033,5),
(3998,'I',0.36,0.27,0.09,0.2075,0.098,0.039,0.062,6),
(333,'I',0.3,0.22,0.08,0.121,0.0475,0.042,0.035,5),
(1837,'I',0.415,0.31,0.09,0.2815,0.1245,0.0615,0.085,6),
(2813,'I',0.24,0.17,0.05,0.0545,0.0205,0.016,0.0155,5),
(930,'I',0.44,0.345,0.13,0.4495,0.209,0.0835,0.134,6),
(1436,'I',0.385,0.3,0.09,0.247,0.1225,0.044,0.0675,5),
(3972,'I',0.4,0.295,0.095,0.252,0.1105,0.0575,0.066,6),
(1433,'I',0.365,0.255,0.08,0.1985,0.0785,0.0345,0.053,5),
(1252,'I',0.405,0.285,0.09,0.2645,0.1265,0.0505,0.075,6),
(3439,'I',0.43,0.335,0.105,0.378,0.188,0.0785,0.09,6),
(1250,'I',0.395,0.27,0.1,0.2985,0.1445,0.061,0.082,5),
(2865,'I',0.31,0.23,0.07,0.1245,0.0505,0.0265,0.038,6),
(3411,'I',0.415,0.31,0.105,0.3595,0.167,0.083,0.0915,6),
(1539,'I',0.355,0.27,0.075,0.1775,0.079,0.0315,0.054,6),
(1990,'I',0.28,0.21,0.075,0.1195,0.053,0.0265,0.03,6),
(1771,'I',0.455,0.335,0.105,0.422,0.229,0.0865,0.1,6),
(2291,'I',0.325,0.27,0.1,0.185,0.08,0.0435,0.065,6),
(3381,'I',0.19,0.13,0.045,0.0265,0.009,0.005,0.009,5),
(1545,'I',0.37,0.27,0.095,0.2175,0.097,0.046,0.065,6),
(652,'I',0.335,0.245,0.09,0.1665,0.0595,0.04,0.06,6),
(3434,'I',0.365,0.27,0.105,0.2155,0.0915,0.0475,0.063,6),
(2004,'I',0.375,0.28,0.08,0.226,0.105,0.047,0.065,6),
(2000,'I',0.35,0.25,0.07,0.1605,0.0715,0.0335,0.046,6),
(3946,'I',0.235,0.175,0.065,0.0615,0.0205,0.02,0.019,6),
(177,'I',0.315,0.21,0.06,0.125,0.06,0.0375,0.035,5),
(920,'I',0.41,0.31,0.09,0.3335,0.1635,0.061,0.091,6),
(3437,'I',0.38,0.275,0.095,0.2505,0.0945,0.0655,0.075,6),
(2630,'I',0.33,0.24,0.075,0.163,0.0745,0.033,0.048,6),
(1092,'I',0.45,0.33,0.11,0.3685,0.16,0.0885,0.102,6),
(3476,'I',0.4,0.315,0.085,0.2675,0.116,0.0585,0.0765,6),
(3526,'I',0.33,0.23,0.085,0.1695,0.079,0.026,0.0505,6),
(1534,'I',0.295,0.215,0.07,0.121,0.047,0.0155,0.0405,6),
(921,'I',0.415,0.33,0.09,0.3595,0.17,0.081,0.09,6),
(2206,'I',0.275,0.22,0.08,0.1365,0.0565,0.0285,0.042,6),
(1218,'I',0.315,0.23,0.08,0.1375,0.0545,0.031,0.0445,5),
(1998,'I',0.335,0.25,0.08,0.1695,0.0695,0.044,0.0495,6),
(2455,'I',0.275,0.2,0.065,0.092,0.0385,0.0235,0.027,5),
(2548,'I',0.23,0.18,0.05,0.064,0.0215,0.0135,0.02,5),
(3996,'I',0.245,0.175,0.055,0.0785,0.04,0.018,0.02,5),
(3408,'I',0.35,0.265,0.08,0.192,0.081,0.0465,0.053,6),
(3907,'M',0.245,0.18,0.065,0.0635,0.0245,0.0135,0.02,4),
(3850,'M',0.385,0.3,0.115,0.3435,0.1645,0.085,0.1025,6),
(124,'M',0.37,0.265,0.075,0.214,0.09,0.051,0.07,6),
(2583,'M',0.53,0.41,0.14,0.681,0.3095,0.1415,0.1835,6),
(526,'M',0.175,0.125,0.04,0.024,0.0095,0.006,0.005,4),
(2184,'M',0.495,0.4,0.155,0.8085,0.2345,0.1155,0.35,6),
(2132,'M',0.32,0.24,0.08,0.18,0.08,0.0385,0.055,6),
(651,'M',0.255,0.18,0.065,0.079,0.034,0.014,0.025,5),
(612,'M',0.195,0.145,0.05,0.032,0.01,0.008,0.012,4),
(958,'M',0.5,0.39,0.135,0.6595,0.3145,0.1535,0.1565,6),
(3174,'M',0.35,0.265,0.09,0.2265,0.0995,0.0575,0.065,6),
(265,'M',0.27,0.2,0.08,0.1205,0.0465,0.028,0.04,6),
(519,'M',0.325,0.23,0.09,0.147,0.06,0.034,0.045,4),
(2382,'M',0.155,0.115,0.025,0.024,0.009,0.005,0.0075,5),
(698,'M',0.28,0.205,0.1,0.1165,0.0545,0.0285,0.03,5),
(2381,'M',0.175,0.135,0.04,0.0305,0.011,0.0075,0.01,5),
(516,'M',0.27,0.195,0.08,0.1,0.0385,0.0195,0.03,6),
(831,'M',0.415,0.305,0.1,0.325,0.156,0.0505,0.091,6),
(3359,'M',0.285,0.215,0.075,0.106,0.0415,0.023,0.035,5);
CREATE TABLE abalone_train_small AS (
SELECT * FROM abalone_train_small_tmp
);
-- create epsilon input table
CREATE TABLE abalone_eps (
sex TEXT,
epsilon DOUBLE PRECISION);
INSERT INTO abalone_eps(sex, epsilon) VALUES
('I', 0.2),
('M', 0.05);
SELECT svm_classification(
'svm_train_data',
'm5',
'label',
'ind',
NULL,NULL,NULL,
'init_stepsize=0.01, max_iter=20, lambda=0.000002');
SELECT svm_predict('m5','svm_test_data', 'id', 'svm_test_5');
-- accuracy without cv
SELECT
count(*) AS misclassification_count
FROM svm_test_5 NATURAL JOIN svm_test_data
WHERE prediction <> label;
-- SVM with kernels -----------------------------------------------------------
-- verify gaussian kernel mapping dimensions
SELECT svm_classification(
'svm_train_data',
'm6',
'label',
'ind',
'gaussian',
'n_components=3, fit_intercept=false',
NULL,
'max_iter=2');
SELECT svm_predict('m6','svm_test_data', 'id', 'svm_test_6');
SELECT
assert(
array_upper(coef, 1) = 3,
'The dimension of the coefficients must be equal to n_components (3)!')
FROM m6;
-- verify gaussian kernel with grouping
-- verify partial string support in kernel specification
SELECT svm_regression(
'abalone_train_small',
'svr_mdl_m',
'rings',
'ARRAY[1,diameter,shell,shucked,length]',
'gau',
'n_components=10',
'sex',
'max_iter=2, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.05',
false);
SELECT svm_predict('svr_mdl_m','abalone_train_small', 'id', 'svm_test_mdl_m');
SELECT
assert(
array_upper(coef, 1) = 10,
'The dimension of the coefficients must be equal to n_components (10)!')
FROM svr_mdl_m;
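-- Informational only (assumption: the grouping column is carried into the
-- model table): with grouping on sex there should be one model row per
-- distinct value of sex.
SELECT sex, array_upper(coef, 1) AS n_coef
FROM svr_mdl_m
ORDER BY sex;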
-- verify gaussian kernel with cross validation
SELECT svm_classification(
'svm_train_data',
'm7',
'label',
'ind',
'gaussian',
'n_components=3, fit_intercept=true',
NULL,
'init_stepsize=[0.01, 0.1], max_iter=2, n_folds=3, lambda=[0.01, 0.1, 0.5], validation_result=m7_cv');
SELECT * FROM m7_cv;
SELECT svm_predict('m7','svm_test_data', 'id', 'svm_test_7');
SELECT
assert(
array_upper(coef, 1) = 4,
'The dimension of the coefficients must be equal to n_components + 1 (4)!')
FROM m7;
-- verify gaussian kernel with out-of-memory support
SELECT svm_classification(
'svm_train_data',
'm8',
'label',
'ind',
'gaussian',
'n_components=3, fit_in_memory=False',
NULL,
'max_iter=2, n_folds=3, lambda=[0.01, 0.1, 0.5]');
SELECT
assert(
array_upper(coef, 1) = 3,
'The dimension of the coefficients must be equal to n_components (3)!')
FROM m8;
CREATE TABLE kernel_data (
index bigint,
x1 double precision,
x2 double precision,
y double precision
);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (0, 0.400000000000000022, -0.699999999999999956, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (1, -1.5, -1, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (2, -1.39999999999999991, -0.900000000000000022, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (3, -1.30000000000000004, -1.19999999999999996, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (4, -1.10000000000000009, -0.200000000000000011, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (5, -1.19999999999999996, -0.400000000000000022, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (6, -0.5, 1.19999999999999996, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (7, -1.5, 2.10000000000000009, 0);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (8, 1, 1, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (9, 1.30000000000000004, 0.800000000000000044, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (10, 1.19999999999999996, 0.5, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (11, 0.200000000000000011, -2, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (12, 0.5, -2.39999999999999991, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (13, 0.200000000000000011, -2.29999999999999982, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (14, 0, -2.70000000000000018, 1);
INSERT INTO kernel_data (index, x1, x2, y) VALUES (15, 1.30000000000000004, 2.10000000000000009, 1);
-- verify poly kernel mapping dimensions
SELECT svm_classification(
'svm_train_data',
'poly_mapping',
'label',
'ind',
'poly',
'n_components=3, fit_intercept=true',
NULL,
'max_iter=2');
SELECT svm_predict('poly_mapping','svm_test_data', 'id', 'svm_test_poly_mapping');
SELECT
assert(
array_upper(coef, 1) = 4,
'The dimension of the coefficients must be equal to n_components + 1 (4)!')
FROM poly_mapping;
-- verify poly kernel with grouping
-- verify partial string support in kernel specification
SELECT svm_regression(
'abalone_train_small',
'svr_mdl_poly',
'rings',
'ARRAY[1,diameter,shell,shucked,length]',
'po',
'degree=2, n_components=10, fit_intercept=true',
'sex',
'max_iter=2, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.05',
false);
SELECT svm_predict('svr_mdl_poly','abalone_train_small', 'id', 'svm_test_poly');
SELECT
assert(
array_upper(coef, 1) = 11,
'The dimension of the coefficients must be equal to n_components + 1 (11)!')
FROM svr_mdl_poly;
SELECT svm_classification(
'kernel_data',
'm10',
'y',
'array[x1, x2]',
'gaussian',
'gamma=1, n_components=20, random_state=2',
NULL,
'init_stepsize=1, max_iter=10');
SELECT * FROM m10;
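-- Informational sketch (assumption: not part of the original tests): the
-- toy kernel_data set should be nearly separable with a gaussian kernel,
-- so an in-sample prediction should make few mistakes.
DROP TABLE IF EXISTS m10_predict;
SELECT svm_predict('m10', 'kernel_data', 'index', 'm10_predict');
SELECT count(*) AS misclassification_count
FROM m10_predict JOIN kernel_data USING (index)
WHERE prediction <> y;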
-- Test for class weight
CREATE TABLE svm_unbalanced (
index bigint,
x1 double precision,
x2 double precision,
y bigint,
y_text text
);
COPY svm_unbalanced (index, x1, x2, y, y_text) FROM stdin delimiter '|';
0|2.43651804549486251|-0.917634620475113127|0|zero
1|-0.792257628395183544|-1.60945293323425576|0|zero
2|1.29811144398701783|-3.45230804532042423|0|zero
3|2.61721764632472009|-1.14181035134265407|0|zero
4|0.478558644085647855|-0.374055563216115106|0|zero
5|2.19316190556746093|-3.09021106424648107|0|zero
6|-0.483625806020261229|-0.576081532002623464|0|zero
7|1.70065416350315601|-1.64983690097104629|0|zero
8|-0.258642311325653629|-1.31678762688205753|0|zero
9|0.0633206200733892471|0.87422282057373335|0|zero
10|-1.65092876581938186|1.7170855647594212|0|zero
11|1.35238608088919321|0.753741508352802292|0|zero
12|1.35128392389661767|-1.02559178876149959|0|zero
13|-0.184335338277972272|-1.40365415138860317|0|zero
14|-0.40183211943902386|0.795533200107279015|0|zero
15|-1.03749112758796347|-0.595130290283966024|0|zero
16|-1.03075905017939906|-1.26780846224807942|0|zero
17|-1.00686919625522853|-0.0189968983783520423|0|zero
18|-1.67596552295291668|0.351623546725638225|0|zero
19|2.48970326566480571|1.11306624086600348|0|zero
20|-0.287753328542422415|-1.3314434461272544|0|zero
21|-1.12073744062625646|2.53868190154161999|0|zero
22|0.0762116321640434469|-0.955493469854030053|0|zero
23|0.286373227001199049|3.15038270471826332|0|zero
24|0.180238428722443722|0.925804664561128865|0|zero
25|0.450255479933741265|-0.528374769740277972|0|zero
26|-1.71377729703321036|-0.524014083619316229|0|zero
27|-0.313341350062167179|0.879934786773296507|0|zero
28|1.25847512081175728|1.39665312195533597|0|zero
29|0.428380987881388176|1.32771174640609213|0|zero
30|-1.1315969114949791|1.87930223284993181|0|zero
31|0.769394730627013246|-0.447139252654073505|0|zero
32|0.73277721980624555|-0.113357569531583588|0|zero
33|1.69744408117714052|2.27972522463329819|0|zero
34|3.27836310979974233|-2.09474450323220651|0|zero
35|-2.16617070814438417|-0.756698794419676801|0|zero
36|0.240055604171745707|1.31425338167433736|0|zero
37|0.473452420862407852|-3.03330182373600454|0|zero
38|-0.459306018942557737|1.24196196391086922|0|zero
39|0.345142103046575111|1.14301677046803718|0|zero
40|-0.333492213915538904|-0.301137103394996164|0|zero
41|0.279842086482426478|0.615077470812384508|0|zero
42|0.297449580190154605|0.178512968711188214|0|zero
43|-1.00599342943354575|0.56634567948137915|0|zero
44|0.182731906487155399|1.6942258618678796|0|zero
45|1.7983768198522605|0.277734626225915771|0|zero
46|-0.562927425135171244|-0.958095611181333573|0|zero
47|0.635241531096169321|0.116010102522839123|0|zero
48|-0.515780513356613346|0.065395285251370408|0|zero
49|-0.930001265922193898|1.04704805110832844|0|zero
50|-0.670692847178997353|1.8367615572082483|0|zero
51|0.605237462686200045|0.890367784855600419|0|zero
52|-1.64236776861156275|0.254073649588002159|0|zero
53|1.11083467664441216|-1.43055090271190188|0|zero
54|-0.399327759005433103|0.0489218200400378389|0|zero
55|-2.05967598037013344|0.472739088063437674|0|zero
56|1.2692409713775501|-1.28927391124797941|0|zero
57|0.525818967996161013|-1.96842511685614774|0|zero
58|-0.0580432638990766719|-2.42365853205494197|0|zero
59|1.682126562353496|0.613350806905241686|0|zero
60|-0.0369254338136675297|-1.16274242875373934|0|zero
61|1.91063389523816496|2.95065262388210225|0|zero
62|-2.78697279667012809|1.85424604567923046|0|zero
63|2.44147612972335981|0.507017544861713687|0|zero
64|-1.79890204850277913|1.29501797631603233|0|zero
65|-0.271380453117225695|-0.905880941689885866|0|zero
66|-1.84508720350044264|0.825806243964323006|0|zero
67|1.18921029887902163|-0.935296094519687427|0|zero
68|0.78086450561005627|-1.71651208443471415|0|zero
69|1.20279154780701703|0.0698509476362183107|0|zero
70|-0.279854657861023148|-0.152618808793717808|0|zero
71|1.30332923550880198|1.12561745979751215|0|zero
72|0.794197986529063815|0.206551814996079108|0|zero
73|0.116731691869058879|0.927570392997786763|0|zero
74|0.348741838768106827|1.02382711029672779|0|zero
75|-0.465175160277089994|-3.65225664616070844|0|zero
76|1.55823690278912119|3.28046947046138637|0|zero
77|0.662046665352873154|-0.150232849925249656|0|zero
78|-0.204667115844049508|-0.178581281662214819|0|zero
79|0.0261141124500068982|-1.68302809312033252|0|zero
80|-0.775641686880341852|-1.49554024147539444|0|zero
81|0.373198742081655654|-0.444961728556294012|0|zero
82|0.742816985966940679|-0.26205473961375142|0|zero
83|1.47950278173186289|0.320300852003162662|0|zero
84|3.28604959345460035|-2.8445413843366385|0|zero
85|-0.970375032382362002|1.35223033747306642|0|zero
86|3.79248856020959701|-0.37295216657319008|0|zero
87|0.0655034897675836614|-0.339471363770407764|0|zero
88|1.9971856688813876|-0.430961795214028331|0|zero
89|1.02010475981715665|-0.479702398348006764|0|zero
90|-1.9088381328689914|0.470321580695148234|0|zero
91|0.754777220152989092|1.93983882379839256|0|zero
92|-0.165670539625974472|-0.926043095568541363|0|zero
93|0.844141644928539492|0.361105638356598369|0|zero
94|0.420997615683958493|-0.109669055620916639|0|zero
95|1.7405078549906543|0.554239074563585565|0|zero
96|2.85698806251147186|1.66658504784075689|0|zero
97|0.988574694150315292|-2.44115751092438593|0|zero
98|0.903478920443443689|0.630423305470589446|0|zero
99|1.2164275092053336|1.56666314206088808|0|zero
100|1.79956090410553671|2.41200280922520394|1|one
101|1.71884728449045499|2.97743903750451722|1|one
102|1.33402416674137592|1.1196557198006083|1|one
103|1.1746393670879498|1.55472220791847571|1|one
104|1.4404423007201359|2.97803945185182073|1|one
105|1.83675025096090794|1.32866210531128193|1|one
106|2.55719148838989607|1.7067380305892037|1|one
107|1.38157331172930142|2.43791946382464975|1|one
108|2.31168108828901619|1.7825216585223862|1|one
109|2.70377000012061419|2.06455078985536256|1|one
\.
DROP TABLE IF EXISTS svm_out, svm_out_summary;
SELECT svm_classification(
'svm_unbalanced',
'svm_out',
'y',
'ARRAY[1, x1, x2]',
'linear',
NULL,
NULL,
'max_iter=1000, init_stepsize=0.1, class_weight=balanced'
);
DROP TABLE IF EXISTS svm_predict_out;
SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out');
-- We check that the prediction accuracy for the under-represented class is
-- reasonably high. Without the class weight, it can drop to as low as 50%.
SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced class is too low')
FROM svm_unbalanced JOIN svm_predict_out
using (index)
WHERE y = prediction and y = 1;
-- Test case with class_weight specified as a mapping. svm_unbalanced has
-- unbalanced data with 10x more examples for class 0 compared to 1. A
-- mapping with {1:10, 0:1} should be the same as balanced.
DROP TABLE IF EXISTS svm_out, svm_out_summary;
SELECT svm_classification(
'svm_unbalanced',
'svm_out',
'y',
'ARRAY[1, x1, x2]',
'linear',
NULL,
NULL,
'max_iter=1000, init_stepsize=0.1, class_weight={1:10, 0:1}'
);
DROP TABLE IF EXISTS svm_predict_out;
SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out');
-- We check that the prediction accuracy for the under-represented class is
-- reasonably high. Without the class weight, it can drop to as low as 50%.
SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced numeric class with mapping class_weight is too low')
FROM svm_unbalanced JOIN svm_predict_out
using (index)
WHERE y = prediction and y = 1;
-- Test case for class_weight with text class values.
DROP TABLE IF EXISTS svm_out, svm_out_summary;
SELECT svm_classification(
'svm_unbalanced',
'svm_out',
'y_text',
'ARRAY[1, x1, x2]',
'linear',
NULL,
NULL,
'max_iter=1000, init_stepsize=0.1, class_weight={zero:1, one:10}'
);
DROP TABLE IF EXISTS svm_predict_out;
SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out');
-- We check that the prediction accuracy for the under-represented class is
-- reasonably high. Without the class weight, it can drop to as low as 50%.
SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced text class with mapping class_weight is too low')
FROM svm_unbalanced JOIN svm_predict_out
using (index)
WHERE y_text = prediction and y_text = 'one';
-- Cross validation tests
SELECT svm_one_class(
'svm_normalized',
'svm_model_expression2',
'ind',
'gaussian',
NULL,
NULL,
'init_stepsize=0.01, max_iter=3, lambda=[0.0002, 0.2], '
'n_folds=3, epsilon = [0.003, 0.2]'
);
\x on
SELECT * FROM svm_model_expression2;
SELECT * FROM svm_model_expression2_summary;
\x off
SELECT svm_predict('svm_model_expression2', 'svm_test_normalized', 'id', 'svm_test_model_expression2');
SELECT svm_regression(
'svr_train_data',
'm1',
'label',
'ind',
'poly',
NULL,
NULL,
'init_stepsize=0.01, max_iter=3, lambda=[0.0002, 0.2], '
'n_folds=3, epsilon = [0.003, 0.2]',
true);
SELECT svm_predict('m1','svm_test_data', 'id', 'svm_test_8');
SELECT svm_regression(
'svr_train_data',
'm2',
'label',
'ind',
NULL,NULL,NULL,
'init_stepsize=0.01, max_iter=2, lambda=[0.0002, 0.2], n_folds=3',
false);
-- check which lambda is selected
SELECT reg_params FROM m2_summary;
-- epsilon values are ignored for classification;
-- the validation table only contains init_stepsize and lambda
SELECT svm_classification(
'svm_train_data',
'm3',
'label',
'ind',
NULL,NULL,NULL,
'init_stepsize=[0.01, 1], max_iter=3, lambda=[20, 0.0002, 0.02], '
'n_folds=3, epsilon=[0.1, 1], validation_result=val_res');
SELECT * FROM val_res;
SELECT svm_classification(
'svm_train_data',
'm4',
'label',
'ind',
NULL,NULL,NULL,
'init_stepsize=0.01, max_iter=20, lambda=[20, 0.0002, 0.02], '
'n_folds=3, validation_result=val_res2');
SELECT * FROM val_res2;
-- check which lambda is selected
SELECT reg_params FROM m1_summary;
SELECT svm_predict('m1','svm_test_data', 'id', 'svm_test_reg_params');
-- verify poly kernel with cross validation
SELECT svm_classification(
'svm_train_data',
'm9',
'label',
'ind',
'poly',
'n_components=3',
NULL,
'max_iter=2, n_folds=3, lambda=[0.01, 0.1, 0.5]');
SELECT svm_predict('m9','svm_test_data', 'id', 'svm_test_9');
SELECT
assert(
array_upper(coef, 1) = 3,
'The dimension of the coefficients must be equal to n_components (3)!')
FROM m9;