| --------------------------------------------------------------------------- |
| -- training |
| -- Fixture for the single-level regression tree test: six rows, each with a |
| -- one-element categorical array, a one-element continuous array and a |
| -- float8 response. |
| DROP TABLE IF EXISTS dummy_dt_con_src CASCADE; |
| CREATE TABLE dummy_dt_con_src ( |
| id INTEGER, |
| cat INTEGER[], |
| con FLOAT8[], |
| y FLOAT8 |
| ); |
| |
| -- Columns are named explicitly so the rows do not depend on declaration order. |
| INSERT INTO dummy_dt_con_src (id, cat, con, y) VALUES |
| (1, ARRAY[0], ARRAY[0], 0.5), |
| (2, ARRAY[0], ARRAY[1], 0.5), |
| (3, ARRAY[0], ARRAY[4], 0.5), |
| (4, ARRAY[0], ARRAY[5], 0.1), |
| (5, ARRAY[0], ARRAY[6], 0.1), |
| (6, ARRAY[1], ARRAY[9], 0.1); |
| |
| ------------------------------------------------------------ |
| -- entropy |
| -- For i = 0 and i = 1, assert that _dst_compute_entropy(i, 10) is within a |
| -- relative error of 1e-6 of 1.0; the failure message includes the value found. |
| SELECT |
| assert( |
| relative_error(_dst_compute_entropy(seq.i, 10), 1.0) < 1e-6, |
| 'expect 1 bit entropy, got ' || CAST(_dst_compute_entropy(seq.i, 10) AS text) |
| ) |
| FROM generate_series(0, 1) AS seq(i); |
| |
| ------------------------------------------------------------ |
| -- con_splits |
| -- Rebuild dummy_splits from the continuous features of dummy_dt_con_src. |
| -- The later training statements read the value back with |
| -- (SELECT splits FROM dummy_splits). |
| DROP TABLE IF EXISTS dummy_splits; |
| CREATE TABLE dummy_splits AS |
| SELECT |
| _dst_compute_con_splits(src.con, 9, CAST(5 AS smallint)) AS splits |
| FROM dummy_dt_con_src AS src; |
| |
| -- Show the computed splits in the test output. |
| SELECT splits FROM dummy_splits; |
| |
| --------------------------------------------------------------------------- |
| -- cat encoding |
| -- Assert that _map_catlevel_to_int('{B}', '{A,B}', ARRAY[2], TRUE) matches |
| -- ARRAY[1] to within a relative error of 1e-6. |
| SELECT |
| assert( |
| relative_error(enc.mapped, ARRAY[1]) < 1e-6, |
| 'wrong results in _map_catlevel_to_int()' |
| ) |
| FROM ( |
| SELECT _map_catlevel_to_int('{B}', '{A,B}', ARRAY[2], TRUE) AS mapped |
| ) AS enc; |
| |
| ------------------------------------------------------------ |
| -- test training aggregate manually |
| \x on |
| -- Expanded display stays on for the statements that follow (until \x off). |
| -- Print a freshly initialized tree without training it. |
| SELECT |
| _print_decision_tree( |
| _initialize_decision_tree(TRUE, 'mse', CAST(1 AS smallint), CAST(5 AS smallint)) |
| ); |
| |
| ------------------------------------------------------------------------- |
| -- regression tree for single level |
| -- Manual single training pass on dummy_dt_con_src: |
| -- 1. _compute_leaf_stats receives every row's cat/con/y (presumably as an |
| -- aggregate over the whole table -- confirm against its definition); |
| -- 2. its result is passed to _dt_apply together with the initial tree; |
| -- 3. the resulting tree_state is rendered by _print_decision_tree and |
| -- expanded into columns with (tree).*. |
| -- NOTE: all asserts below are commented out, so this statement only prints |
| -- the tree and verifies nothing. |
| SELECT |
| * |
| -- , assert(tree_depth IS NOT NULL, 'dummy_dt (empty tree returned)'), |
| -- assert(tree_depth = 1, 'dummy_dt (wrong tree_depth)'), |
| -- assert(relative_error(feature_indices, ARRAY[0,-1,-1]) = 0., |
| -- 'dummy_dt (wrong feature_indices)'), |
| -- assert(feature_thresholds[1] = 4, 'dummy_dt (wrong feature_thresholds)'), |
| -- assert(is_categorical[1] = 0, 'dummy_dt (wrong is_categorical)'), |
| -- assert(relative_error(predictions[1:1][2:3], ARRAY[0.5, 0.1]) < 1e-6, |
| -- 'dummy_dt (wrong predictions)') |
| FROM ( |
| SELECT (tree).* |
| FROM ( |
| SELECT |
| _print_decision_tree( |
| (_dt_apply( |
| -- The same initial tree expression is passed to both _dt_apply and |
| -- _compute_leaf_stats. |
| _initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint), |
| _compute_leaf_stats( |
| _initialize_decision_tree(TRUE, 'mse', 1::smallint, 5::smallint), |
| cat::integer[], |
| con::double precision[], |
| y::double precision, |
| 1.0::double precision, |
| '{2}'::integer[], |
| -- Same splits value that _dt_apply receives below, here with an explicit |
| -- BYTEA8 cast. |
| (SELECT splits FROM dummy_splits)::BYTEA8, |
| 2::smallint, |
| FALSE), |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| 1::smallint, |
| 10::smallint, |
| False::boolean, |
| 1::integer |
| )).tree_state |
| ) as tree |
| FROM dummy_dt_con_src |
| ) q1 |
| ) q2 |
| ; |
| |
| ------------------------------------------------------------------------- |
| -- classification tree for multi-levels |
| -- Fixture for the multi-level classification tree test: two categorical |
| -- features, two continuous features, an integer class label and a row weight. |
| -- Several id values repeat (8 and 11); the id column is not referenced by |
| -- the training statements in this section. |
| DROP TABLE IF EXISTS dummy_dt_cat_src CASCADE; |
| CREATE TABLE dummy_dt_cat_src ( |
| id INTEGER, |
| cat INTEGER[], |
| con FLOAT8[], |
| y INTEGER, |
| weight FLOAT8 |
| ); |
| |
| -- Columns are named explicitly so the rows do not depend on declaration order. |
| INSERT INTO dummy_dt_cat_src (id, cat, con, y, weight) VALUES |
| (1, ARRAY[0, 0], ARRAY[100, 0], 0, 10), |
| (2, ARRAY[0, 0], ARRAY[200, 1], 0, 10), |
| (3, ARRAY[0, 1], ARRAY[1000, 4], 0, 10), |
| (4, ARRAY[0, 1], ARRAY[1000, 4], 0, 1), |
| (5, ARRAY[0, 1], ARRAY[1000, 4], 0, 1), |
| (6, ARRAY[0, 1], ARRAY[10, 5], 1, 1), |
| (7, ARRAY[0, 2], ARRAY[10.10, 10], 1, 1), |
| (8, ARRAY[1, 2], ARRAY[12, 9], 1, 1), |
| (8, ARRAY[1, 2], ARRAY[13.4, 9], 1, 1), |
| (8, ARRAY[1, 2], ARRAY[15, 9], 1, 1), |
| (9, ARRAY[1, 2], ARRAY[15.4, 9], 1, 0.9), |
| (10, ARRAY[1, 2], ARRAY[5.4, 10], 1, 0.9), |
| (11, ARRAY[1, 2], ARRAY[15.4, 10], 1, 0.9), |
| (11, ARRAY[1, 2], ARRAY[25.4, 1], 1, 0.9), |
| (12, ARRAY[1, 2], ARRAY[15.4, -20], 1, 0.9), |
| (13, ARRAY[1, 2], ARRAY[15.4, 200], 1, 0.9), |
| (14, ARRAY[1, 2], ARRAY[15.4, 0], 1, 0.9), |
| (15, ARRAY[1, 2], ARRAY[11.4, 1], 1, 0.9), |
| (16, ARRAY[1, 2], ARRAY[124, 2], 1, 0.9); |
| |
| -- Replace dummy_splits with splits computed from the continuous features of |
| -- dummy_dt_cat_src; the classification training statements read it back. |
| DROP TABLE IF EXISTS dummy_splits; |
| CREATE TABLE dummy_splits AS |
| SELECT |
| _dst_compute_con_splits(src.con, 9, CAST(3 AS smallint)) AS splits |
| FROM dummy_dt_cat_src AS src; |
| |
| -- Add first iteration of tree in an output table |
| -- Drop any leftover table first so the script can be re-run in the same |
| -- database; every other fixture in this file does the same. |
| DROP TABLE IF EXISTS dummy_dt_output; |
| -- Iteration 1: start from a freshly initialized 'entropy' tree, feed the |
| -- rows of dummy_dt_cat_src (with their weight column) through |
| -- _compute_leaf_stats, apply the result with _dt_apply and store the |
| -- resulting tree_state under iteration = 1. |
| CREATE TABLE dummy_dt_output AS |
| SELECT |
| 1::integer AS iteration, |
| (_dt_apply( |
| _initialize_decision_tree(FALSE, 'entropy', 2::smallint, 5::smallint), |
| _compute_leaf_stats( |
| _initialize_decision_tree(FALSE, 'entropy', 2::smallint, 5::smallint), |
| cat, |
| con, |
| y, |
| weight::float8, |
| '{2, 3}'::integer[], |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| FALSE), |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| 0::smallint, |
| 3::smallint, |
| False::boolean, |
| 1::integer |
| )).tree_state AS tree |
| FROM dummy_dt_cat_src |
| ; |
| |
| -- Show the stored first-iteration tree in the test output. |
| SELECT * FROM dummy_dt_output; |
| |
| -- Train tree further |
| -- Iteration 2: read the tree stored under iteration = 1, run another |
| -- _compute_leaf_stats / _dt_apply pass over dummy_dt_cat_src and append the |
| -- resulting tree_state as iteration = 2. |
| -- NOTE(review): this pass uses a constant weight of 1.0, whereas iteration 1 |
| -- used the weight column -- confirm this is intended. |
| insert into dummy_dt_output |
| SELECT |
| 2::integer as iteration, |
| (_dt_apply( |
| (SELECT tree from dummy_dt_output where iteration = 1), |
| _compute_leaf_stats( |
| (SELECT tree from dummy_dt_output where iteration = 1), |
| cat, |
| con, |
| y, |
| 1.0::float8, |
| '{2, 3}'::integer[], |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| FALSE), |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| 1::smallint, |
| 3::smallint, |
| False::boolean, |
| 1::integer |
| )).tree_state as tree |
| FROM dummy_dt_cat_src |
| ; |
| |
| -- Final iteration to wrap the training of tree |
| -- (this iteration ends training but does not expand tree further) |
| -- Runs one more _compute_leaf_stats / _dt_apply pass starting from the tree |
| -- stored under iteration = 2, prints the resulting tree_state with |
| -- _print_decision_tree and expands it into columns with (tree).*. |
| -- The result is only displayed; it is not inserted into dummy_dt_output. |
| -- NOTE: all asserts below are commented out, so nothing is verified here. |
| SELECT |
| * |
| -- , assert(tree_depth IS NOT NULL, 'dummy_dt (empty tree returned)'), |
| -- assert(tree_depth = 2, 'dummy_dt (wrong tree_depth)'), |
| -- assert(relative_error(feature_indices, ARRAY[1,0,-2,-2,-2,-3,-3]) = 0., |
| -- 'dummy_dt (wrong feature_indices)'), |
| -- assert(relative_error(feature_thresholds, ARRAY[1,12,0,0,0,0,0]) = 0., |
| -- 'dummy_dt (wrong feature_thresholds)'), |
| -- assert(relative_error(is_categorical, ARRAY[1,0,0,0,0,0,0]) = 0., |
| -- 'dummy_dt (wrong is_categorical)') |
| FROM ( |
| SELECT (tree).* |
| FROM ( |
| SELECT |
| _print_decision_tree( |
| (_dt_apply( |
| -- The iteration-2 tree is passed to both _dt_apply and _compute_leaf_stats. |
| (SELECT tree from dummy_dt_output where iteration = 2), |
| _compute_leaf_stats( |
| (SELECT tree from dummy_dt_output where iteration = 2), |
| cat, |
| con, |
| y, |
| 1.0::float8, |
| '{2, 3}'::integer[], |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| FALSE), |
| (SELECT splits FROM dummy_splits), |
| 2::smallint, |
| 1::smallint, |
| 3::smallint, |
| False::boolean, |
| 1::integer |
| )).tree_state |
| ) as tree |
| FROM dummy_dt_cat_src |
| ) q1 |
| ) q2 |
| ; |
| -------------------------------------------------------------------------------- |
| -- Validate tree train function ------------------------------------------------ |
| -- Shared fixture for the tree_train / tree_predict tests below. |
| -- "OUTLOOK" and "Cont_features" are quoted mixed-case identifiers, so every |
| -- reference to them must be quoted as well. |
| DROP TABLE IF EXISTS dt_golf CASCADE; |
| CREATE TABLE dt_golf ( |
| id integer NOT NULL, |
| "OUTLOOK" text, |
| temperature double precision, |
| humidity double precision, |
| "Cont_features" double precision[], |
| cat_features text[], |
| windy boolean, |
| class text |
| ) ; |
| |
| -- The rows contain NULLs in several columns: humidity (id 4), temperature |
| -- (ids 6 and 16), windy (ids 7 and 12) and "OUTLOOK" (id 15). |
| -- Ids run from 1 to 16 and are not inserted in order (14 comes last). |
| INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,"Cont_features",cat_features, windy,class) VALUES |
| (1, 'sunny', 85, 85,ARRAY[85, 85], ARRAY['a', 'b'], false, 'Don''t Play'), |
| (2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['a', 'b'], true, 'Don''t Play'), |
| (3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['a', 'b'], false, 'Play'), |
| (4, 'rain', 70, NULL, ARRAY[70, 96], ARRAY['a', 'b'], false, 'Play'), |
| (5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['a', 'b'], false, 'Play'), |
| (6, 'rain', NULL, 70, ARRAY[65, 70], ARRAY['a', 'b'], true, 'Don''t Play'), |
| (7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['c', 'b'], NULL , 'Play'), |
| (8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['a', 'b'], false, 'Don''t Play'), |
| (9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['a', 'b'], false, 'Play'), |
| (10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['a', 'b'], false, 'Play'), |
| (11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['a', 'd'], true, 'Play'), |
| (12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['c', 'b'], NULL, 'Play'), |
| (13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, 'Play'), |
| (15, NULL, 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, 'Play'), |
| (16, 'overcast', NULL, 75, ARRAY[81, 75], ARRAY['a', 'd'], false, 'Play'), |
| (14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['c', 'b'], true, 'Don''t Play'); |
| |
| -- no grouping |
| -- Train on dt_golf with temperature as the response and no grouping column, |
| -- then print and display the resulting model. |
| DROP TABLE IF EXISTS train_output, train_output_summary; |
| SELECT tree_train( |
| CAST('dt_golf' AS text), -- source table |
| CAST('train_output' AS text), -- output model table |
| CAST('id' AS text), -- id column |
| CAST('temperature::double precision' AS text), -- response |
| CAST('humidity, windy, "Cont_features"' AS text), -- features |
| CAST(NULL AS text), -- exclude columns |
| CAST('gini' AS text), -- split criterion |
| CAST(NULL AS text), -- no grouping |
| CAST(NULL AS text), -- no weights |
| CAST(10 AS integer), -- max depth |
| CAST(6 AS integer), -- min split |
| CAST(2 AS integer), -- min bucket |
| CAST(8 AS integer), -- number of bins per continuous variable |
| 'cp=0.01' -- cost-complexity pruning parameter |
| ); |
| |
| SELECT _print_decision_tree(m.tree) FROM train_output AS m; |
| SELECT tree_display('train_output', FALSE); |
| ------------------------------------------------------------------------- |
| |
| -- grouping |
| -- Train on dt_golf with temperature as the response, grouped by the class |
| -- column, then print and display the resulting model. |
| DROP TABLE IF EXISTS train_output, train_output_summary, predict_output; |
| SELECT tree_train( |
| CAST('dt_golf' AS text), -- source table |
| CAST('train_output' AS text), -- output model table |
| CAST('id' AS text), -- id column |
| CAST('temperature::double precision' AS text), -- response |
| CAST('"OUTLOOK", humidity, windy, cat_features' AS text), -- features |
| CAST(NULL AS text), -- exclude columns |
| CAST('gini' AS text), -- split criterion |
| CAST('class' AS text), -- grouping |
| CAST(NULL AS text), -- no weights |
| CAST(10 AS integer), -- max depth |
| CAST(6 AS integer), -- min split |
| CAST(2 AS integer), -- min bucket |
| CAST(3 AS integer), -- number of bins per continuous variable |
| 'cp=0.01' -- cost-complexity pruning parameter |
| ); |
| |
| SELECT _print_decision_tree(m.tree) FROM train_output AS m; |
| SELECT tree_display('train_output', FALSE); |
| |
| |
| -- testing tree_predict with a category not present in training table |
| -- dt_golf2 = dt_golf plus one extra row whose "OUTLOOK" value ('humid') does |
| -- not occur in dt_golf. It is used to run tree_predict against the grouped |
| -- model trained just above. |
| -- Drop first so the script can be re-run in the same database. |
| DROP TABLE IF EXISTS dt_golf2; |
| -- The extra row uses id 17: dt_golf already has ids 1..16, and reusing one of |
| -- them would make the JOIN ... USING (id) below return duplicated rows. |
| -- UNION ALL is sufficient because the added row differs from every dt_golf |
| -- row, so no de-duplication pass is needed. |
| CREATE TABLE dt_golf2 AS |
| SELECT * FROM dt_golf |
| UNION ALL |
| SELECT 17 AS id, 'humid' AS "OUTLOOK", 71 AS temperature, 80 AS humidity, |
| ARRAY[90, 90] AS "Cont_features", ARRAY['b', 'c'] AS cat_features, |
| true AS windy, 'Don''t Play' AS class; |
| \x off |
| SELECT * FROM dt_golf2; |
| SELECT tree_predict('train_output', 'dt_golf2', 'predict_output'); |
| -- Show each prediction next to the source row it was computed for. |
| SELECT * |
| FROM predict_output |
| INNER JOIN dt_golf2 |
| USING (id); |
| \x on |
| SELECT * FROM train_output; |
| SELECT * FROM train_output_summary; |
| ------------------------------------------------------------------------- |
| |
| -- no grouping, cross-validation |
| -- Train on dt_golf without grouping, passing n_folds=3 in the pruning |
| -- parameters, then predict on the training table and dump the model, |
| -- summary and cross-validation tables. |
| DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv, predict_output; |
| SELECT tree_train( |
| CAST('dt_golf' AS text), -- source table |
| CAST('train_output' AS text), -- output model table |
| CAST('id' AS text), -- id column |
| CAST('temperature::double precision' AS text), -- response |
| CAST('"OUTLOOK", cat_features, "Cont_features"' AS text), -- features |
| CAST(NULL AS text), -- exclude columns |
| CAST('mse' AS text), -- split criterion |
| CAST(NULL AS text), -- no grouping |
| CAST(NULL AS text), -- no weights |
| CAST(NULL AS integer), -- max depth |
| CAST(6 AS integer), -- min split |
| CAST(2 AS integer), -- min bucket |
| CAST(8 AS integer), -- number of bins per continuous variable |
| 'cp=0.01, n_folds=3', -- pruning parameters (cp plus cross-validation folds) |
| 'max_surrogates=2, null_as_category=True' -- presumably NULL-handling options; confirm against tree_train |
| ); |
| |
| SELECT _print_decision_tree(m.tree) FROM train_output AS m; |
| SELECT tree_display('train_output', FALSE); |
| SELECT tree_predict('train_output', 'dt_golf', 'predict_output'); |
| \x off |
| -- Show each prediction next to the source row it was computed for. |
| SELECT * |
| FROM predict_output |
| INNER JOIN dt_golf |
| USING (id); |
| \x on |
| SELECT * FROM train_output; |
| SELECT * FROM train_output_summary; |
| SELECT * FROM train_output_cv; |
| ------------------------------------------------------------------------- |
| |
| -- One explore_value per class value; the table name is passed to |
| -- __build_tree in the statement that follows. |
| DROP TABLE IF EXISTS group_cp; |
| CREATE TABLE group_cp ( |
| class TEXT, |
| explore_value DOUBLE PRECISION |
| ); |
| INSERT INTO group_cp (class, explore_value) VALUES |
| ('Don''t Play', 0.5), |
| ('Play', -0.1); |
| |
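| -- Test that __build_tree works with an array of cp values (one per group) |
| -- instead of a single cp value. |
| -- NOTE(review): this call was originally meant to stay commented out until |
| -- __build_tree accepted an array of cp values; it is enabled now -- confirm |
| -- that __build_tree supports this before relying on the result. |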
| -- Call __build_tree directly on dt_golf, grouped by class, passing the |
| -- group_cp table name; then display the tree, describe the summary table |
| -- and dump both output tables. |
| DROP TABLE IF EXISTS train_output, train_output_summary; |
| SELECT __build_tree( |
| FALSE, |
| 'mse', |
| 'dt_golf', |
| 'train_output', |
| 'id', |
| 'temperature::double precision', |
| FALSE, |
| '"OUTLOOK"', |
| CAST(ARRAY['"OUTLOOK"'] AS text[]), |
| '{}', |
| '{}', |
| '{humidity}', |
| 'class', |
| '1', |
| 4, |
| 2, |
| 1, |
| 3, |
| 'group_cp', |
| CAST(0 AS smallint), |
| 'notice', |
| NULL, |
| 0 |
| ); |
| |
| SELECT tree_display('train_output', FALSE); |
| |
| \d train_output_summary |
| \x on |
| SELECT * FROM train_output; |
| SELECT * FROM train_output_summary; |
| |
| ------------------------------------------------------------------------- |
| -- Train and predict with a single array-valued feature column. |
| -- Drop first so the script can be re-run in the same database. |
| DROP TABLE IF EXISTS array_test; |
| -- Ten rows (i = 1..10); dep equals i. feat is built from |
| -- array_fill(array_of_float(100), random() * i) -- presumably a 100-element |
| -- float array filled with a random value scaled by i; confirm against those |
| -- helpers. Because of random(), the feature values differ between runs. |
| CREATE TABLE array_test AS |
| SELECT i AS id, array_fill(array_of_float(100), random() * i) AS feat, i AS dep |
| FROM generate_series(1, 10) AS i; |
| |
| |
| DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv, predict_output; |
| SELECT tree_train('array_test'::text, -- source table |
| 'train_output'::text, -- output model table |
| 'id'::text, -- id column |
| 'dep::double precision'::text, -- response |
| 'feat'::text, -- features |
| NULL::text, -- exclude columns |
| 'mse'::text, -- split criterion |
| NULL::text, -- no grouping |
| NULL::text, -- no weights |
| 2::integer, -- max depth |
| 6::integer, -- min split |
| 2::integer, -- min bucket |
| 8::integer -- number of bins per continuous variable |
| ); |
| SELECT * FROM train_output_summary; |
| SELECT _print_decision_tree(tree) FROM train_output; |
| SELECT tree_display('train_output', False); |
| SELECT tree_predict('train_output', 'array_test', 'predict_output'); |