---------------------------------------------------------------------------
-- training
DROP TABLE IF EXISTS dummy_dt_con_src CASCADE;
CREATE TABLE dummy_dt_con_src (
id integer,
cat integer[],
con float8[],
y float8
);
INSERT INTO dummy_dt_con_src VALUES
(1, '{0}'::integer[], ARRAY[0], 0.5),
(2, '{0}'::integer[], ARRAY[1], 0.5),
(3, '{0}'::integer[], ARRAY[4], 0.5),
(4, '{0}'::integer[], ARRAY[5], 0.1),
(5, '{0}'::integer[], ARRAY[6], 0.1),
(6, '{1}'::integer[], ARRAY[9], 0.1);
------------------------------------------------------------
-- entropy
SELECT
assert(relative_error(_dst_compute_entropy(i, 10), 1.0) < 1e-6,
'expect 1 bit entropy, got ' || _dst_compute_entropy(i, 10)::text)
FROM
(SELECT generate_series(0, 1) AS i) seq
;
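-- As a plain-SQL cross-check (standard PostgreSQL only): two equally likely
-- labels carry -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1 bit of entropy.
SELECT -sum(p * log(2.0, p)) AS entropy_bits
FROM (VALUES (0.5::numeric), (0.5::numeric)) AS dist(p);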
------------------------------------------------------------
-- con_splits
DROP TABLE IF EXISTS dummy_splits;
CREATE TABLE dummy_splits AS
SELECT _dst_compute_con_splits(
con,
9,
5::smallint
) as splits
FROM dummy_dt_con_src
;
SELECT * FROM dummy_splits;
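-- A rough plain-SQL sketch of what the split computation approximates
-- (an assumption: _dst_compute_con_splits produces quantile-style split
-- points; 5 buckets imply 4 interior boundaries):
SELECT percentile_cont(ARRAY[0.2, 0.4, 0.6, 0.8]::float8[])
       WITHIN GROUP (ORDER BY c) AS approx_boundaries
FROM (SELECT unnest(con) AS c FROM dummy_dt_con_src) vals;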
---------------------------------------------------------------------------
-- cat encoding
SELECT
assert(
relative_error(_map_catlevel_to_int('{B}', '{A,B}', ARRAY[2]), ARRAY[1]) < 1e-6,
'wrong results in _map_catlevel_to_int()')
;
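-- The equivalent zero-based lookup in plain SQL (PostgreSQL 9.5+): 'B' sits
-- at 1-based position 2 in the level list {A,B}, i.e. integer code 1.
SELECT array_position(ARRAY['A','B'], 'B') - 1 AS cat_code;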
------------------------------------------------------------
-- test training aggregate manually
\x on
SELECT _print_decision_tree(_initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint));
-------------------------------------------------------------------------
-- regression tree for single level
SELECT
*
-- , assert(tree_depth IS NOT NULL, 'dummy_dt (empty tree returned)'),
-- assert(tree_depth = 1, 'dummy_dt (wrong tree_depth)'),
-- assert(relative_error(feature_indices, ARRAY[0,-1,-1]) = 0.,
-- 'dummy_dt (wrong feature_indices)'),
-- assert(feature_thresholds[1] = 4, 'dummy_dt (wrong feature_thresholds)'),
-- assert(is_categorical[1] = 0, 'dummy_dt (wrong is_categorical)'),
-- assert(relative_error(predictions[1:1][2:3], ARRAY[0.5, 0.1]) < 1e-6,
-- 'dummy_dt (wrong predictions)')
FROM (
SELECT (tree).*
FROM (
SELECT
_print_decision_tree(
(_dt_apply(
_initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint),
_compute_leaf_stats(
_initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint),
cat::integer[],
con::double precision[],
y::double precision,
1.0::double precision,
'{2}'::integer[],
(SELECT splits FROM dummy_splits)::BYTEA8,
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
1::smallint,
10::smallint,
FALSE::boolean,
1::integer
)).tree_state
) as tree
FROM dummy_dt_con_src
) q1
) q2
;
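-- Hand check of the expected leaf means: splitting the single continuous
-- feature at 4 (the threshold named in the disabled asserts above) yields
-- the 0.5 / 0.1 leaf predictions.
SELECT con[1] <= 4 AS left_child, avg(y) AS leaf_mean
FROM dummy_dt_con_src
GROUP BY 1
ORDER BY 1;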
-------------------------------------------------------------------------
-- classification tree for multiple levels
DROP TABLE IF EXISTS dummy_dt_cat_src CASCADE;
CREATE TABLE dummy_dt_cat_src (
id integer,
cat integer[],
con float8[],
y integer,
weight float8
);
INSERT INTO dummy_dt_cat_src VALUES
(1, '{0, 0}'::integer[], ARRAY[100, 0], 0, 10),
(2, '{0, 0}'::integer[], ARRAY[200, 1], 0, 10),
(3, '{0, 1}'::integer[], ARRAY[1000, 4], 0, 10),
(4, '{0, 1}'::integer[], ARRAY[1000, 4], 0, 1),
(5, '{0, 1}'::integer[], ARRAY[1000, 4], 0, 1),
(6, '{0, 1}'::integer[], ARRAY[10, 5], 1, 1),
(7, '{0, 2}'::integer[], ARRAY[10.10, 10], 1, 1),
(8, '{1, 2}'::integer[], ARRAY[12, 9], 1, 1),
(8, '{1, 2}'::integer[], ARRAY[13.4, 9], 1, 1),
(8, '{1, 2}'::integer[], ARRAY[15, 9], 1, 1),
(9, '{1, 2}'::integer[], ARRAY[15.4, 9], 1, 0.9),
(10, '{1, 2}'::integer[], ARRAY[5.4, 10], 1, 0.9),
(11, '{1, 2}'::integer[], ARRAY[15.4, 10], 1, 0.9),
(11, '{1, 2}'::integer[], ARRAY[25.4, 1], 1, 0.9),
(12, '{1, 2}'::integer[], ARRAY[15.4, -20], 1, 0.9),
(13, '{1, 2}'::integer[], ARRAY[15.4, 200], 1, 0.9),
(14, '{1, 2}'::integer[], ARRAY[15.4, 0], 1, 0.9),
(15, '{1, 2}'::integer[], ARRAY[11.4, 1], 1, 0.9),
(16, '{1, 2}'::integer[], ARRAY[124, 2], 1, 0.9);
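-- Class counts per level of the second categorical feature: level 2 is
-- pure class 1, so a root split that separates level 2 (as the disabled
-- asserts near the final iteration expect) is the natural first cut.
SELECT cat[2] AS cat2_level, y, count(*) AS n
FROM dummy_dt_cat_src
GROUP BY 1, 2
ORDER BY 1, 2;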
DROP TABLE IF EXISTS dummy_splits;
CREATE TABLE dummy_splits AS
SELECT _dst_compute_con_splits(
con,
9,
3::smallint
) as splits
FROM dummy_dt_cat_src
;
-- Add the first iteration of the tree to an output table
CREATE TABLE dummy_dt_output AS
SELECT
1::integer as iteration,
(_dt_apply(
_initialize_decision_tree(FALSE, 'entropy', 2::smallint, 5::smallint),
_compute_leaf_stats(
_initialize_decision_tree(FALSE, 'entropy', 2::smallint, 5::smallint),
cat,
con,
y,
weight::float8,
'{2, 3}'::integer[],
(SELECT splits FROM dummy_splits),
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
0::smallint,
3::smallint,
FALSE::boolean,
1::integer
)).tree_state as tree
FROM dummy_dt_cat_src
;
SELECT * FROM dummy_dt_output;
-- Train the tree further
INSERT INTO dummy_dt_output
SELECT
2::integer as iteration,
(_dt_apply(
(SELECT tree from dummy_dt_output where iteration = 1),
_compute_leaf_stats(
(SELECT tree from dummy_dt_output where iteration = 1),
cat,
con,
y,
1.0::float8,
'{2, 3}'::integer[],
(SELECT splits FROM dummy_splits),
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
1::smallint,
3::smallint,
FALSE::boolean,
1::integer
)).tree_state as tree
FROM dummy_dt_cat_src
;
-- Final iteration to wrap up training of the tree
-- (this iteration finalizes training but does not expand the tree further)
SELECT
*
-- , assert(tree_depth IS NOT NULL, 'dummy_dt (empty tree returned)'),
-- assert(tree_depth = 2, 'dummy_dt (wrong tree_depth)'),
-- assert(relative_error(feature_indices, ARRAY[1,0,-2,-2,-2,-3,-3]) = 0.,
-- 'dummy_dt (wrong feature_indices)'),
-- assert(relative_error(feature_thresholds, ARRAY[1,12,0,0,0,0,0]) = 0.,
-- 'dummy_dt (wrong feature_thresholds)'),
-- assert(relative_error(is_categorical, ARRAY[1,0,0,0,0,0,0]) = 0.,
-- 'dummy_dt (wrong is_categorical)')
FROM (
SELECT (tree).*
FROM (
SELECT
_print_decision_tree(
(_dt_apply(
(SELECT tree from dummy_dt_output where iteration = 2),
_compute_leaf_stats(
(SELECT tree from dummy_dt_output where iteration = 2),
cat,
con,
y,
1.0::float8,
'{2, 3}'::integer[],
(SELECT splits FROM dummy_splits),
2::smallint,
FALSE),
(SELECT splits FROM dummy_splits),
2::smallint,
1::smallint,
3::smallint,
FALSE::boolean,
1::integer
)).tree_state
) as tree
FROM dummy_dt_cat_src
) q1
) q2
;
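-- Hand check of the second-level split expected by the disabled asserts
-- (continuous feature 0 at threshold 12): within categorical levels {0,1}
-- the single y = 1 row is the only one with con[1] <= 12.
SELECT con[1] <= 12 AS left_child, y, count(*) AS n
FROM dummy_dt_cat_src
WHERE cat[2] <= 1
GROUP BY 1, 2
ORDER BY 1, 2;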
--------------------------------------------------------------------------------
-- Validate tree train function ------------------------------------------------
DROP TABLE IF EXISTS dt_golf CASCADE;
CREATE TABLE dt_golf (
id integer NOT NULL,
"OUTLOOK" text,
temperature double precision,
humidity double precision,
windy boolean,
class text
) ;
INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,windy,class) VALUES
(1, 'sunny', 85, 85, false, 'Don''t Play'),
(2, 'sunny', 80, 90, true, 'Don''t Play'),
(3, 'overcast', 83, 78, false, 'Play'),
(4, 'rain', 70, 96, false, 'Play'),
(5, 'rain', 68, 80, false, 'Play'),
(6, 'rain', 65, 70, true, 'Don''t Play'),
(7, 'overcast', 64, 65, true, 'Play'),
(8, 'sunny', 72, 95, false, 'Don''t Play'),
(9, 'sunny', 69, 70, false, 'Play'),
(10, 'rain', 75, 80, false, 'Play'),
(11, 'sunny', 75, 70, true, 'Play'),
(12, 'overcast', 72, 90, true, 'Play'),
(13, 'overcast', 81, 75, false, 'Play'),
(14, 'rain', 71, 80, true, 'Don''t Play');
-- no grouping
DROP TABLE IF EXISTS train_output, train_output_summary;
SELECT tree_train('dt_golf'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'temperature::double precision'::text, -- response
'humidity, windy'::text, -- features
NULL::text, -- exclude columns
'gini'::text, -- split criterion
NULL::text, -- no grouping
NULL::text, -- no weights
10::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
8::integer, -- number of bins per continuous variable
'cp=0.01' -- cost-complexity pruning parameter
);
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
-------------------------------------------------------------------------
-- grouping
DROP TABLE IF EXISTS train_output, train_output_summary, predict_output;
SELECT tree_train('dt_golf'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'temperature::double precision'::text, -- response
'humidity, windy'::text, -- features
NULL::text, -- exclude columns
'gini'::text, -- split criterion
'class'::text, -- grouping
NULL::text, -- no weights
10::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
3::integer, -- number of bins per continuous variable
'cp=0.01' -- cost-complexity pruning parameter
);
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
SELECT tree_predict('train_output', 'dt_golf', 'predict_output');
\x off
SELECT *
FROM
predict_output
JOIN
dt_golf
USING (id);
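-- A quick fit summary; the prediction column name estimated_temperature
-- follows tree_predict's estimated_<response> convention (an assumption
-- worth checking against the actual predict_output schema).
SELECT sqrt(avg((temperature - estimated_temperature) ^ 2)) AS rmse
FROM predict_output JOIN dt_golf USING (id);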
\x on
SELECT * FROM train_output;
SELECT * FROM train_output_summary;
-------------------------------------------------------------------------
-- no grouping, cross-validation
DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv, predict_output;
SELECT tree_train('dt_golf'::text, -- source table
'train_output'::text, -- output model table
'id'::text, -- id column
'temperature::double precision'::text, -- response
'humidity, windy'::text, -- features
NULL::text, -- exclude columns
'mse'::text, -- split criterion
NULL::text, -- no grouping
NULL::text, -- no weights
NULL::integer, -- max depth
6::integer, -- min split
2::integer, -- min bucket
8::integer, -- number of bins per continuous variable
                  'cp=0.01, n_folds=5', -- cost-complexity pruning with 5-fold cross-validation
                  'max_surrogates=2' -- null handling: number of surrogate splits
);
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
SELECT tree_predict('train_output', 'dt_golf', 'predict_output');
\x off
SELECT *
FROM
predict_output
JOIN
dt_golf
USING (id);
\x on
SELECT * FROM train_output;
SELECT * FROM train_output_summary;
SELECT * FROM train_output_cv;
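-- Sketch of selecting the best pruning parameter from the cross-validation
-- results; the column names cp and cv_error_avg are assumptions about the
-- layout of train_output_cv, so the query is left commented out:
-- SELECT cp FROM train_output_cv ORDER BY cv_error_avg LIMIT 1;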
-------------------------------------------------------------------------
DROP TABLE IF EXISTS group_cp;
CREATE TABLE group_cp (
    class TEXT,
    explore_value DOUBLE PRECISION);
INSERT INTO group_cp VALUES
('Don''t Play', 0.5),
('Play', -0.1);
-- Test that __build_tree works with a table of per-group cp values
-- (group_cp above supplies a separate cp for each value of the grouping
-- column) instead of a single scalar cp.
DROP TABLE IF EXISTS train_output, train_output_summary;
-- (argument annotations below are inferred from the tree_train interface
-- and may not match the internal parameter names exactly)
SELECT __build_tree(
    FALSE,                            -- is_classification (FALSE: regression)
    'mse',                            -- split criterion
    'dt_golf',                        -- source table
    'train_output',                   -- output model table
    'id',                             -- id column
    'temperature::double precision',  -- response
    FALSE,                            -- response is not boolean
    ARRAY['"OUTLOOK"']::text[],       -- categorical features
    '{}',                             -- (empty categorical sub-list)
    '{}',                             -- (empty categorical sub-list)
    '{humidity}',                     -- continuous features
    'class',                          -- grouping column
    '1',                              -- weights (constant 1)
    4,                                -- max depth
    2,                                -- min split
    1,                                -- min bucket
    5,                                -- number of bins
    'group_cp',                       -- table of per-group cp values
    0::smallint,                      -- max surrogates
    'notice',                         -- message level
    0                                 -- number of folds
);
SELECT tree_display('train_output', FALSE);
\d train_output_summary
\x on
SELECT * FROM train_output;
SELECT * FROM train_output_summary;