---------------------------------------------------------------------------
-- training
DROP TABLE IF EXISTS dummy_dt_con_src CASCADE;
CREATE TABLE dummy_dt_con_src (
id integer,
cat integer[],
con float8[],
y float8
);
INSERT INTO dummy_dt_con_src VALUES
(1, '{0}'::integer[], ARRAY[0], 0.5),
(2, '{0}'::integer[], ARRAY[1], 0.5),
(3, '{0}'::integer[], ARRAY[4], 0.5),
(4, '{0}'::integer[], ARRAY[5], 0.1),
(5, '{0}'::integer[], ARRAY[6], 0.1),
(6, '{1}'::integer[], ARRAY[9], 0.1);
------------------------------------------------------------
-- entropy
SELECT
assert(relative_error(_dst_compute_entropy(i, 10), 1.0) < 1e-6,
'expect 1 bit entropy, got ' || _dst_compute_entropy(i, 10)::text)
FROM
(SELECT generate_series(0, 1) AS i) seq
;
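-- Hand-computed reference for the assertion above: two equally likely
-- outcomes carry -sum(p * log2(p)) = 1 bit of entropy. A minimal plain-SQL
-- sketch (not part of the module under test), using PostgreSQL's
-- two-argument log(base, x):
SELECT -sum(p * log(2.0, p)) AS expected_entropy_bits
FROM (VALUES (0.5::numeric), (0.5::numeric)) t(p);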
------------------------------------------------------------
-- con_splits
DROP TABLE IF EXISTS dummy_splits;
CREATE TABLE dummy_splits AS
SELECT _dst_compute_con_splits(
con,
9,
5::smallint
) as splits
FROM dummy_dt_con_src
;
SELECT * FROM dummy_splits;
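-- For orientation only: with 5 buckets over the six con[1] values
-- {0,1,4,5,6,9}, one would expect four interior boundaries near the
-- 20/40/60/80th percentiles. A rough plain-SQL analogue (assumes a
-- PostgreSQL 9.4+ backend for ordered-set aggregates; the exact binning
-- strategy of _dst_compute_con_splits may differ):
SELECT percentile_cont(ARRAY[0.2, 0.4, 0.6, 0.8]::float8[])
WITHIN GROUP (ORDER BY con[1]) AS approx_boundaries
FROM dummy_dt_con_src;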
---------------------------------------------------------------------------
-- cat encoding
SELECT
assert(
relative_error(_map_catlevel_to_int('{B}', '{A,B}', ARRAY[2], TRUE), ARRAY[1]) < 1e-6,
'wrong results in _map_catlevel_to_int()')
;
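-- Plain-SQL analogue of the zero-based level lookup asserted above:
-- level 'B' within the level list {A,B} encodes to 1 (requires
-- PostgreSQL 9.5+ for array_position; shown only to document the
-- expected encoding):
SELECT array_position(ARRAY['A','B'], 'B') - 1 AS expected_code;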
------------------------------------------------------------
-- test training aggregate manually
\x on
SELECT _print_decision_tree(_initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint));
-------------------------------------------------------------------------
-- regression tree with a single level
SELECT
*
-- , assert(tree_depth IS NOT NULL, 'dummy_dt (empty tree returned)'),
-- assert(tree_depth = 1, 'dummy_dt (wrong tree_depth)'),
-- assert(relative_error(feature_indices, ARRAY[0,-1,-1]) = 0.,
-- 'dummy_dt (wrong feature_indices)'),
-- assert(feature_thresholds[1] = 4, 'dummy_dt (wrong feature_thresholds)'),
-- assert(is_categorical[1] = 0, 'dummy_dt (wrong is_categorical)'),
-- assert(relative_error(predictions[1:1][2:3], ARRAY[0.5, 0.1]) < 1e-6,
-- 'dummy_dt (wrong predictions)')
FROM (
SELECT (tree).*
FROM (
SELECT
_print_decision_tree(
(_dt_apply(
_initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint),
_compute_leaf_stats(
_initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint),
cat::integer[],
con::double precision[],
y::double precision,
1.0::double precision,
'{2}'::integer[],
(SELECT splits FROM dummy_splits)::BYTEA8,
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
1::smallint,
10::smallint,
False::boolean,
1::integer
)).tree_state
) as tree
FROM dummy_dt_con_src
) q1
) q2
;
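-- Plain-SQL cross-check of the leaf predictions asserted (commented) above:
-- an MSE split at con[1] <= 4 separates the y = 0.5 rows from the y = 0.1
-- rows, so each leaf predicts the per-side mean.
SELECT (con[1] <= 4) AS left_child, avg(y) AS leaf_prediction
FROM dummy_dt_con_src
GROUP BY 1;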
-------------------------------------------------------------------------
-- classification tree with multiple levels
DROP TABLE IF EXISTS dummy_dt_cat_src CASCADE;
CREATE TABLE dummy_dt_cat_src (
id integer,
cat integer[],
con float8[],
y integer,
weight float8
);
INSERT INTO dummy_dt_cat_src VALUES
(1, '{0, 0}'::integer[], ARRAY[100, 0], 0, 10),
(2, '{0, 0}'::integer[], ARRAY[200, 1], 0, 10),
(3, '{0, 1}'::integer[], ARRAY[1000, 4], 0, 10),
(4, '{0, 1}'::integer[], ARRAY[1000, 4], 0, 1),
(5, '{0, 1}'::integer[], ARRAY[1000, 4], 0, 1),
(6, '{0, 1}'::integer[], ARRAY[10, 5], 1, 1),
(7, '{0, 2}'::integer[], ARRAY[10.10, 10], 1, 1),
(8, '{1, 2}'::integer[], ARRAY[12, 9], 1, 1),
(9, '{1, 2}'::integer[], ARRAY[13.4, 9], 1, 1),
(10, '{1, 2}'::integer[], ARRAY[15, 9], 1, 1),
(11, '{1, 2}'::integer[], ARRAY[15.4, 9], 1, 0.9),
(12, '{1, 2}'::integer[], ARRAY[5.4, 10], 1, 0.9),
(13, '{1, 2}'::integer[], ARRAY[15.4, 10], 1, 0.9),
(14, '{1, 2}'::integer[], ARRAY[25.4, 1], 1, 0.9),
(15, '{1, 2}'::integer[], ARRAY[15.4, -20], 1, 0.9),
(16, '{1, 2}'::integer[], ARRAY[15.4, 200], 1, 0.9),
(17, '{1, 2}'::integer[], ARRAY[15.4, 0], 1, 0.9),
(18, '{1, 2}'::integer[], ARRAY[11.4, 1], 1, 0.9),
(19, '{1, 2}'::integer[], ARRAY[124, 2], 1, 0.9);
DROP TABLE IF EXISTS dummy_splits;
CREATE TABLE dummy_splits AS
SELECT _dst_compute_con_splits(
con,
9,
3::smallint
) as splits
FROM dummy_dt_cat_src
;
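-- Inspect the recomputed split boundaries (mirrors the check on the
-- regression fixture above).
SELECT * FROM dummy_splits;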
-- Store the first iteration of the tree in an output table
create table dummy_dt_output as
SELECT
1::integer as iteration,
(_dt_apply(
_initialize_decision_tree(FALSE, 'entropy', 2::smallint, 5::smallint),
_compute_leaf_stats(
_initialize_decision_tree(FALSE, 'entropy', 2::smallint, 5::smallint),
cat,
con,
y,
weight::float8,
'{2, 3}'::integer[],
(SELECT splits FROM dummy_splits),
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
0::smallint,
3::smallint,
False::boolean,
1::integer
)).tree_state as tree
FROM dummy_dt_cat_src
;
SELECT * FROM dummy_dt_output;
-- Train tree further
insert into dummy_dt_output
SELECT
2::integer as iteration,
(_dt_apply(
(SELECT tree from dummy_dt_output where iteration = 1),
_compute_leaf_stats(
(SELECT tree from dummy_dt_output where iteration = 1),
cat,
con,
y,
1.0::float8,
'{2, 3}'::integer[],
(SELECT splits FROM dummy_splits),
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
1::smallint,
3::smallint,
False::boolean,
1::integer
)).tree_state as tree
FROM dummy_dt_cat_src
;
-- Final iteration to wrap up training of the tree
-- (this iteration ends training but does not expand the tree further)
SELECT
*
-- , assert(tree_depth IS NOT NULL, 'dummy_dt (empty tree returned)'),
-- assert(tree_depth = 2, 'dummy_dt (wrong tree_depth)'),
-- assert(relative_error(feature_indices, ARRAY[1,0,-2,-2,-2,-3,-3]) = 0.,
-- 'dummy_dt (wrong feature_indices)'),
-- assert(relative_error(feature_thresholds, ARRAY[1,12,0,0,0,0,0]) = 0.,
-- 'dummy_dt (wrong feature_thresholds)'),
-- assert(relative_error(is_categorical, ARRAY[1,0,0,0,0,0,0]) = 0.,
-- 'dummy_dt (wrong is_categorical)')
FROM (
SELECT (tree).*
FROM (
SELECT
_print_decision_tree(
(_dt_apply(
(SELECT tree from dummy_dt_output where iteration = 2),
_compute_leaf_stats(
(SELECT tree from dummy_dt_output where iteration = 2),
cat,
con,
y,
1.0::float8,
'{2, 3}'::integer[],
(SELECT splits FROM dummy_splits),
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
1::smallint,
3::smallint,
False::boolean,
1::integer
)).tree_state
) as tree
FROM dummy_dt_cat_src
) q1
) q2
;
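-- Plain-SQL sanity check of the root split asserted (commented) above:
-- splitting on the second categorical feature (cat[2], threshold 1, read
-- here as levels {0,1} vs {2}) separates almost all class-0 weight from
-- class 1. The level-ordering interpretation is an assumption made for
-- illustration only.
SELECT (cat[2] <= 1) AS left_child,
sum(CASE WHEN y = 0 THEN weight ELSE 0 END) AS weight_class0,
sum(CASE WHEN y = 1 THEN weight ELSE 0 END) AS weight_class1
FROM dummy_dt_cat_src
GROUP BY 1;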
--------------------------------------------------------------------------------
-- Validate tree train function ------------------------------------------------
DROP TABLE IF EXISTS dt_golf CASCADE;
CREATE TABLE dt_golf (
id integer NOT NULL,
"OUTLOOK" text,
temperature double precision,
humidity double precision,
"Cont_features" double precision[],
cat_features text[],
windy boolean,
class text
) ;
INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,"Cont_features",cat_features, windy,class) VALUES
(1, 'sunny', 85, 85,ARRAY[85, 85], ARRAY['a', 'b'], false, 'Don''t Play'),
(2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['a', 'b'], true, 'Don''t Play'),
(3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['a', 'b'], false, 'Play'),
(4, 'rain', 70, NULL, ARRAY[70, 96], ARRAY['a', 'b'], false, 'Play'),
(5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['a', 'b'], false, 'Play'),
(6, 'rain', NULL, 70, ARRAY[65, 70], ARRAY['a', 'b'], true, 'Don''t Play'),
(7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['c', 'b'], NULL , 'Play'),
(8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['a', 'b'], false, 'Don''t Play'),
(9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['a', 'b'], false, 'Play'),
(10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['a', 'b'], false, 'Play'),
(11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['a', 'd'], true, 'Play'),
(12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['c', 'b'], NULL, 'Play'),
(13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, 'Play'),
(15, NULL, 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, 'Play'),
(16, 'overcast', NULL, 75, ARRAY[81, 75], ARRAY['a', 'd'], false, 'Play'),
(14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['c', 'b'], true, 'Don''t Play');
-- no grouping
DROP TABLE IF EXISTS train_output, train_output_summary;
SELECT tree_train('dt_golf'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'temperature::double precision'::text, -- response
'humidity, windy, "Cont_features"'::text, -- features
NULL::text, -- exclude columns
'gini'::text, -- split criterion
NULL::text, -- no grouping
NULL::text, -- no weights
10::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
8::integer, -- number of bins per continuous variable
'cp=0.01' -- cost-complexity pruning parameter
);
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
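-- The same model rendered in dot format (tree_display's boolean argument
-- selects dot vs. text output).
SELECT tree_display('train_output', True);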
-------------------------------------------------------------------------
-- grouping
DROP TABLE IF EXISTS train_output, train_output_summary, predict_output;
SELECT tree_train('dt_golf'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'temperature::double precision'::text, -- response
'"OUTLOOK", humidity, windy, cat_features'::text, -- features
NULL::text, -- exclude columns
'gini'::text, -- split criterion
'class'::text, -- grouping
NULL::text, -- no weights
10::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
3::integer, -- number of bins per continuous variable
'cp=0.01' -- cost-complexity pruning parameter
);
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
-- Test tree_predict with a category ('humid') not present in the training table
CREATE TABLE dt_golf2 as
SELECT * FROM dt_golf
UNION
SELECT 17 as id, 'humid' as "OUTLOOK", 71 as temperature, 80 as humidity,
ARRAY[90, 90] as "Cont_features", ARRAY['b', 'c'] as cat_features,
true as windy, 'Don''t Play' as class;
\x off
SELECT * FROM dt_golf2;
SELECT tree_predict('train_output', 'dt_golf2', 'predict_output');
SELECT *
FROM
predict_output
JOIN
dt_golf2
USING (id);
\x on
select * from train_output;
select * from train_output_summary;
-------------------------------------------------------------------------
-- no grouping, cross-validation
DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv, predict_output;
SELECT tree_train('dt_golf'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'temperature::double precision'::text, -- response
'"OUTLOOK", cat_features, "Cont_features"'::text, -- features
NULL::text, -- exclude columns
'mse'::text, -- split criterion
NULL::text, -- no grouping
NULL::text, -- no weights
NULL::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
8::integer, -- number of bins per continuous variable
'cp=0.01, n_folds=3', -- pruning and cross-validation parameters
'max_surrogates=2, null_as_category=True' -- surrogate and NULL-handling parameters
);
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
SELECT tree_predict('train_output', 'dt_golf', 'predict_output');
\x off
SELECT *
FROM
predict_output
JOIN
dt_golf
USING (id);
\x on
select * from train_output;
select * from train_output_summary;
select * from train_output_cv;
-------------------------------------------------------------------------
drop table if exists group_cp;
create table group_cp(class TEXT,
explore_value DOUBLE PRECISION);
insert into group_cp values
('Don''t Play', 0.5),
('Play', -0.1);
-- Test that __build_tree works with per-group cp values supplied through a
-- table (group_cp above) rather than a single scalar cp value
drop table if exists train_output, train_output_summary;
select __build_tree(
FALSE,
'mse',
'dt_golf',
'train_output',
'id',
'temperature::double precision',
FALSE,
'"OUTLOOK"',
ARRAY['"OUTLOOK"']::text[],
'{}',
'{}',
'{humidity}',
'class',
'1',
4,
2,
1,
3,
'group_cp',
0::smallint,
'notice',
NULL,
0
);
select tree_display('train_output', FALSE);
\d train_output_summary
\x on
select * from train_output;
select * from train_output_summary;
-------------------------------------------------------------------------
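-- Wide-array smoke test: each row gets a 100-element feature array filled
-- with a single per-row random value. Note that array_fill(anyarray,
-- anyelement) and array_of_float(integer) below presumably resolve to
-- MADlib's array_ops helpers (fill a pre-sized float8 array with a value),
-- not PostgreSQL's built-in array_fill.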
CREATE TABLE array_test AS
SELECT i as id, array_fill(array_of_float(100), random() * i) as feat, i as dep
FROM generate_series(1, 10) as i;
DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv, predict_output;
SELECT tree_train('array_test'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'dep::double precision'::text, -- response
'feat'::text, -- features
NULL::text, -- exclude columns
'mse'::text, -- split criterion
NULL::text, -- no grouping
NULL::text, -- no weights
2::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
8::integer -- number of bins per continuous variable
);
SELECT * FROM train_output_summary;
SELECT _print_decision_tree(tree) FROM train_output;
SELECT tree_display('train_output', False);
SELECT tree_predict('train_output', 'array_test', 'predict_output');
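-- Inspect the resulting predictions, mirroring the checks done for the
-- earlier models.
\x off
SELECT * FROM predict_output;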