| /* ----------------------------------------------------------------------- *//** |
| * |
| * @file decision_tree.sql_in |
| * |
| * @brief SQL implementation of a decision tree |
| * @date January 2011 |
| * |
| *//* ----------------------------------------------------------------------- */ |
| |
| /** |
| @addtogroup grp_dectree |
| |
| @about |
| |
| This module provides an implementation of the C4.5 decision tree algorithm, |
| but with several modifications. It assumes that: |
| - The data set is very large. |
| - It is sparse (that's why it uses sparse vector data type for storage). |
| - Data features come from a discrete set of variables, example: {1,2,3,4,5}. |
| |
Because we assume a very large amount of sparse input data, the following
additional steps have been implemented:
- The algorithm starts by eliminating all redundant points,
producing a smaller subset of unique, weighted points.
- Further, assuming a very large number of features, at each step the
algorithm tests only a subset of features (instead of all possible features)
to find the best split criterion.
| |
| @input |
| |
| The <b>training data</b> is expected to be of the following form: |
| <pre>{TABLE|VIEW} <em>trainingSource</em> ( |
| ... |
| <em>id</em> INTEGER, |
| <em>features</em> SVEC, |
| <em>class</em> INTEGER, |
| ... |
| )</pre> |
| |
| The <b>data to classify</b> is expected to be of the following form: |
| <pre>{TABLE|VIEW} <em>classifySource</em> ( |
| ... |
| <em>id</em> INTEGER, |
| <em>features</em> SVEC, |
| ... |
| )</pre> |
| |
| @usage |
| |
| - Run the training algorithm on the source data: |
<pre>SELECT * FROM \ref train_tree(
'<em>table_input</em>', '<em>id_col_name</em>', '<em>feature_col_name</em>',
'<em>class_col_name</em>', <em>num_values</em>, <em>max_num_iter</em>
);</pre>
| This will create the MADLIB_SCHEMA.tree table storing an abstract object |
| (representing the model) used for further classification. Column names: |
| <pre> id | tree_location | hash | feature | probability | chisq | maxclass | infogain | live | cat_size | parent_id | jump |
| ----+---------------+------+---------+-------------+-------+----------+----------+------+----------+-----------+------ |
| ...</pre> |
| |
| - Run the classification function using the learned model: |
<pre>SELECT * FROM \ref classify_tree(
'<em>table_name</em>', '<em>id_col_name</em>', '<em>feature_col_name</em>',
<em>num_values</em>);</pre>
| This will create the MADLIB_SCHEMA.classified_points table with the |
| classification results. Column names: |
| <pre> id | feature | jump | class | prob |
| ----+---------+------+-------+------ |
| ...</pre> |
| |
| @examp |
| |
| -# Prepare an input table/view, e.g.: |
| \verbatim |
| sql> SELECT * FROM data.training LIMIT 10; |
| id | feature | class |
| ----+-----------------------------+------- |
| 1 | {1,1,1,1,1,5}:{1,0,1,0,1,0} | 1 |
| 3 | {1,1,1,1,1,5}:{1,0,2,0,2,0} | 3 |
| 5 | {1,1,1,1,1,5}:{2,0,2,0,2,0} | 5 |
| 7 | {1,1,1,1,1,5}:{1,0,1,0,2,0} | 2 |
| 9 | {1,1,1,1,1,5}:{2,0,2,0,1,0} | 4 |
| 11 | {1,1,1,1,1,5}:{1,0,1,0,1,0} | 1 |
| 13 | {1,1,1,1,1,5}:{1,0,2,0,2,0} | 3 |
| 15 | {1,1,1,1,1,5}:{2,0,2,0,2,0} | 5 |
| 17 | {1,1,1,1,1,5}:{1,0,1,0,2,0} | 2 |
| 19 | {1,1,1,1,1,5}:{2,0,2,0,1,0} | 4 |
| (10 rows) |
| \endverbatim |
| -# Train the decision tree model, e.g.: |
| \verbatim |
| sql> SELECT * FROM train_tree( 'data.training', 'id', 'feature', 'class', 5, 5000); |
| INFO: selected_dimensions = {1,2,3,4,5,6,7,8,9,10} , sample dimentions = 10 |
| CONTEXT: PL/pgSQL function "train_tree" line 71 at assignment |
| SQL statement "SELECT mad.Train_Tree( $1 , $2 , $3 , $4 , $5 , $6 , 0)" |
| ... |
| train_tree |
| ------------ |
| |
| (1 row) |
| \endverbatim |
-# Check a few rows from the tree model table:
| \verbatim |
| sql> SELECT * FROM MADLIB_SCHEMA.tree ORDER BY id LIMIT 10; |
| id | tree_location | hash | feature | probability | chisq | maxclass | infogain | live | cat_size | parent_id | jump |
| ----+---------------+------------+---------+-------------------+-----------------------+----------+-------------------+------+----------+-----------+--------------- |
| 1 | {0} | 12459841 | 1 | 0.2 | 0 | 1 | 0.554517744447956 | 0 | 5000 | 0 | [2:3]={2,3} |
| 2 | {0,1} | 2106753344 | 5 | 0.4 | 0 | 1 | 0.395752794785278 | 0 | 2500 | 1 | [2:3]={4,5} |
| 3 | {0,2} | 2106753345 | 5 | 0.4 | 0 | 4 | 0.395752794785278 | 0 | 2500 | 1 | [2:3]={6,7} |
| 4 | {0,1,1} | -856809791 | 3 | 0.666666666666667 | 0 | 1 | 0.636514168294813 | 0 | 1500 | 2 | [2:3]={8,9} |
| 5 | {0,1,2} | -856809790 | 3 | 0.5 | 1.78114410168526e-215 | 2 | 0.693147180559945 | 0 | 1000 | 2 | [2:3]={10,11} |
| 6 | {0,2,1} | -856744128 | 3 | 0.5 | 1.78114410168526e-215 | 3 | 0.693147180559945 | 0 | 1000 | 3 | [2:3]={12,13} |
| 7 | {0,2,2} | -856744127 | 3 | 0.666666666666667 | 0 | 5 | 0.636514168294813 | 0 | 1500 | 3 | [2:3]={14,15} |
| 8 | {0,1,1,1} | -924696128 | 1 | 1 | 1 | 1 | 0 | 0 | 1000 | 4 | |
| 9 | {0,1,1,2} | -924696127 | 1 | 1 | 1 | 2 | 0 | 0 | 500 | 4 | |
| 10 | {0,1,2,1} | -924630465 | 1 | 1 | 1 | 2 | 0 | 0 | 500 | 5 | |
| (10 rows) |
| \endverbatim |
| -# To classify the new data points run: |
| \verbatim |
| sql> SELECT * FROM classify_tree( 'toclassify', 'id', 'feature', 10); |
| classify_tree |
| --------------- |
| |
| (1 row) |
| \endverbatim |
| -# Check classification results: |
| \verbatim |
| SELECT * FROM MADLIB_SCHEMA.classified_points ORDER BY id LIMIT 4; |
| id | feature | jump | class | prob |
| ----+-----------------------------+------+-------+------------------- |
| 1 | {1,1,1,1,1,5}:{3,3,3,3,3,3} | 0 | 5 | 0.207920792079208 |
| 2 | {1,1,1,1,1,5}:{1,0,1,0,4,4} | 0 | 5 | 0.207920792079208 |
| 3 | {1,1,1,1,1,5}:{1,3,1,0,0,0} | 0 | 5 | 0.207920792079208 |
| 4 | {1,1,1,1,1,5}:{0,0,0,0,1,1} | 0 | 1 | 0.333333333333333 |
| (4 rows) |
| \endverbatim |
| |
| @literature |
| |
| [1] http://en.wikipedia.org/wiki/C4.5_algorithm |
| |
| @sa File decision_tree.sql_in documenting the SQL functions. |
| */ |
| |
-- ---------------------------------------------------------------------------
-- SQL bindings for C routines shipped in the MADlib shared module.
-- Each declaration only exposes the C symbol of the same name; the C sources
-- are not part of this file, so the notes below describe how the functions
-- are used here rather than how they are implemented.
-- ---------------------------------------------------------------------------

-- hash_array(INTEGER[]) -> INTEGER
-- Presumably hashes an integer array (TODO confirm against the C source).
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.hash_array(INTEGER[]);
CREATE FUNCTION MADLIB_SCHEMA.hash_array(INTEGER[])
RETURNS INTEGER
AS 'MODULE_PATHNAME', 'hash_array'
LANGUAGE C
IMMUTABLE;

-- findentropy(INTEGER[], INTEGER[], INTEGER, INTEGER) -> DOUBLE PRECISION
-- Not called anywhere in the visible part of this file.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.findentropy(INTEGER[], INTEGER[], INTEGER, INTEGER);
CREATE FUNCTION MADLIB_SCHEMA.findentropy(INTEGER[], INTEGER[], INTEGER, INTEGER)
RETURNS DOUBLE PRECISION
AS 'MODULE_PATHNAME', 'findentropy'
LANGUAGE C
IMMUTABLE;

-- aggr_InfoGain(state, value, weight, num_class, num_values, class) -> state
-- Per-row accumulation step invoked by findinfogain_sfunc.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.aggr_InfoGain(DOUBLE PRECISION[], DOUBLE PRECISION, DOUBLE PRECISION, INTEGER, INTEGER, INTEGER);
CREATE FUNCTION MADLIB_SCHEMA.aggr_InfoGain(DOUBLE PRECISION[], DOUBLE PRECISION, DOUBLE PRECISION, INTEGER, INTEGER, INTEGER)
RETURNS DOUBLE PRECISION[]
AS 'MODULE_PATHNAME', 'aggr_InfoGain'
LANGUAGE C
IMMUTABLE;

-- compute_InfoGain(state, num_class, num_values) -> DOUBLE PRECISION[]
-- Invoked by findinfogain_finalfunc, which reads elements 1..4 of the result.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.compute_InfoGain(DOUBLE PRECISION[], INTEGER, INTEGER);
CREATE FUNCTION MADLIB_SCHEMA.compute_InfoGain(DOUBLE PRECISION[], INTEGER, INTEGER)
RETURNS DOUBLE PRECISION[]
AS 'MODULE_PATHNAME', 'compute_InfoGain'
LANGUAGE C
IMMUTABLE;

-- mallocset(size, value) -> DOUBLE PRECISION[]
-- findinfogain_sfunc calls it as mallocset((num_class+1)*(num_values+1), 0)
-- to obtain the initial accumulator array.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.mallocset(INTEGER, INTEGER);
CREATE FUNCTION MADLIB_SCHEMA.mallocset(INTEGER, INTEGER)
RETURNS DOUBLE PRECISION[]
AS 'MODULE_PATHNAME', 'mallocset'
LANGUAGE C
IMMUTABLE;
| |
-- WeightedNoReplacement(INT4, INT4) -> INT8[]
-- SQL binding for the C symbol WeightedNoReplacement in the MADlib module.
-- find_best_split calls it as (sample_dimentions, feature_dimentions) to pick
-- the feature positions evaluated for a split.
-- NOTE(review): the C source is not visible here; the routine is presumed to
-- draw a random sample. A function whose result can differ between calls with
-- the same arguments must not be IMMUTABLE (the planner may fold or reuse the
-- result), so it is declared VOLATILE.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.WeightedNoReplacement(INT4, INT4);
CREATE FUNCTION MADLIB_SCHEMA.WeightedNoReplacement(INT4, INT4) RETURNS INT8[]
AS 'MODULE_PATHNAME', 'WeightedNoReplacement'
LANGUAGE C VOLATILE;
| |
-- min(x, y): two-argument minimum for double precision values.
-- Returns x when x <= y holds, otherwise y. Because a comparison involving
-- NULL is never true, y is returned whenever either argument is NULL.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.min (
    x double precision,
    y double precision
) RETURNS double precision
AS $$
BEGIN
    RETURN CASE WHEN x <= y THEN x ELSE y END;
END;
$$ LANGUAGE plpgsql;
| |
| ---------------------- CreateHash ---------- START |
| |
-- State and result type of the CreateHash aggregate: one representative
-- point per group plus the number of rows folded into it.
DROP TYPE IF EXISTS MADLIB_SCHEMA.hash_val CASCADE;
CREATE TYPE MADLIB_SCHEMA.hash_val AS (
    id        INT,                  -- id taken from the first row of the group
    feature   MADLIB_SCHEMA.svec,   -- feature vector taken from that row
    class     INT,                  -- class taken from that row
    weight    INT,                  -- number of rows aggregated into the group
    selection INT                   -- always set to 1 by the transition functions
);
| |
-- weight_count(state, id, feature, class): transition function of CreateHash.
--   $1  running state (all fields NULL before the first row)
--   $2  point id, $3 feature vector, $4 class of the current row
-- The first row of a group seeds the state with weight 1; every later row
-- only increments the weight, so the group keeps the first row seen as its
-- representative.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.weight_count(MADLIB_SCHEMA.hash_val, int, MADLIB_SCHEMA.svec, int) CASCADE;
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.weight_count(MADLIB_SCHEMA.hash_val, int, MADLIB_SCHEMA.svec, int) RETURNS MADLIB_SCHEMA.hash_val AS $$
DECLARE
    acc MADLIB_SCHEMA.hash_val;
BEGIN
    acc := $1;

    IF acc.weight IS NULL THEN
        -- uninitialised state: adopt the current row
        acc.id        := $2;
        acc.feature   := $3;
        acc.class     := $4;
        acc.weight    := 1;
        acc.selection := 1;
    ELSE
        acc.weight := acc.weight + 1;
    END IF;

    RETURN acc;
END
$$ LANGUAGE plpgsql;
| |
-- weight_aggr(state1, state2): merge function of CreateHash (registered as
-- prefunc on Greenplum) that combines two partial states.
-- If state2 holds data, its representative point replaces the one in state1
-- and the weights are added; otherwise state1 is returned unchanged.
--
-- A partial state counts as initialised when its weight is set, which is the
-- same marker weight_count uses. Testing the id instead would silently drop
-- the weight of a partial state whose source id column was NULL.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.weight_aggr(MADLIB_SCHEMA.hash_val, MADLIB_SCHEMA.hash_val) CASCADE;
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.weight_aggr(MADLIB_SCHEMA.hash_val, MADLIB_SCHEMA.hash_val) RETURNS MADLIB_SCHEMA.hash_val AS $$
DECLARE
    merged MADLIB_SCHEMA.hash_val;
BEGIN
    merged := $1;

    IF ($2.weight IS NOT NULL) THEN
        merged.id        := $2.id;
        merged.feature   := $2.feature;
        merged.class     := $2.class;
        merged.weight    := COALESCE(merged.weight, 0) + COALESCE($2.weight, 0);
        merged.selection := 1;
    END IF;

    RETURN merged;
END
$$ LANGUAGE plpgsql;
| |
-- CreateHash(id, feature, class): collapses a group of rows into a single
-- hash_val carrying one representative point and the row count (weight).
-- The merge function is only registered when building for Greenplum.
DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.CreateHash(int, MADLIB_SCHEMA.svec, int);
CREATE AGGREGATE MADLIB_SCHEMA.CreateHash(int, MADLIB_SCHEMA.svec, int) (
    STYPE = MADLIB_SCHEMA.hash_val,
    m4_ifdef(`GREENPLUM',`PREFUNC = MADLIB_SCHEMA.weight_aggr,')
    SFUNC = MADLIB_SCHEMA.weight_count
);
| |
| ---------------------- CreateHash ---------- END |
| |
-- gammaln(x): natural logarithm of the gamma function for x > 0.
--
-- The argument range is split into regions. Regions 1-5 evaluate a rational
-- approximation (numerator and denominator polynomials of eight coefficients
-- each, Horner scheme) in a shifted argument; the last region uses an
-- asymptotic series in 1/x^2. Very small arguments use -ln(x) directly.
--
-- x <= 0 raises the error produced by ln(); a NULL argument yields NULL.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gammaln (x double precision)
RETURNS double precision
AS $$
DECLARE
    -- numerator / denominator coefficients, highest power first
    p1 double precision[] := ARRAY[
        4.945235359296727046734888e0,
        2.018112620856775083915565e2,
        2.290838373831346393026739e3,
        1.131967205903380828685045e4,
        2.855724635671635335736389e4,
        3.848496228443793359990269e4,
        2.637748787624195437963534e4,
        7.225813979700288197698961e3
    ]::double precision[];
    q1 double precision[] := ARRAY[
        6.748212550303777196073036e1,
        1.113332393857199323513008e3,
        7.738757056935398733233834e3,
        2.763987074403340708898585e4,
        5.499310206226157329794414e4,
        6.161122180066002127833352e4,
        3.635127591501940507276287e4,
        8.785536302431013170870835e3
    ]::double precision[];
    p2 double precision[] := ARRAY[
        4.974607845568932035012064e0,
        5.424138599891070494101986e2,
        1.550693864978364947665077e4,
        1.847932904445632425417223e5,
        1.088204769468828767498470e6,
        3.338152967987029735917223e6,
        5.106661678927352456275255e6,
        3.074109054850539556250927e6
    ]::double precision[];
    q2 double precision[] := ARRAY[
        1.830328399370592604055942e2,
        7.765049321445005871323047e3,
        1.331903827966074194402448e5,
        1.136705821321969608938755e6,
        5.267964117437946917577538e6,
        1.346701454311101692290052e7,
        1.782736530353274213975932e7,
        9.533095591844353613395747e6
    ]::double precision[];
    p4 double precision[] := ARRAY[
        1.474502166059939948905062e4,
        2.426813369486704502836312e6,
        1.214755574045093227939592e8,
        2.663432449630976949898078e9,
        2.940378956634553899906876e10,
        1.702665737765398868392998e11,
        4.926125793377430887588120e11,
        5.606251856223951465078242e11
    ]::double precision[];
    q4 double precision[] := ARRAY[
        2.690530175870899333379843e3,
        6.393885654300092398984238e5,
        4.135599930241388052042842e7,
        1.120872109616147941376570e9,
        1.488613728678813811542398e10,
        1.016803586272438228077304e11,
        3.417476345507377132798597e11,
        4.463158187419713286462081e11
    ]::double precision[];
    -- series coefficients added after each division by x^2 (large x)
    c double precision[] := ARRAY[
        -1.910444077728e-03,
        8.4171387781295e-04,
        -5.952379913043012e-04,
        7.93650793500350248e-04,
        -2.777777777777681622553e-03,
        8.333333333333333331554247e-02
    ]::double precision[];
    c_start double precision := 5.7083835261e-03;

    d1  double precision := -5.772156649015328605195174e-1;
    d2  double precision := 4.227843350984671393993777e-1;
    d4  double precision := 1.791759469228055000094023e0;
    spi double precision := 0.9189385332046727417803297;

    region int;                      -- which approximation applies to x
    t      double precision;         -- shifted polynomial argument
    num    double precision := 0.0;
    den    double precision := 1.0;
    cp     double precision[];       -- numerator coefficients in use
    cq     double precision[];       -- denominator coefficients in use
    r      double precision;
    lnx    double precision;
    xx     double precision;
    k      int;
BEGIN
    -- Select the region. A NULL x satisfies no test and ends up in the
    -- asymptotic branch, where every term evaluates to NULL.
    IF x <= (1E-10) THEN
        region := 0;
    ELSIF x <= 0.5 THEN
        region := 1;  t := x;                  cp := p1;  cq := q1;
    ELSIF x <= 0.6796875 THEN
        region := 2;  t := (x - 0.5) - 0.5;    cp := p2;  cq := q2;
    ELSIF x <= 1.5 THEN
        region := 3;  t := (x - 0.5) - 0.5;    cp := p1;  cq := q1;
    ELSIF x <= 4.0 THEN
        region := 4;  t := x - 2.0;            cp := p2;  cq := q2;
    ELSIF x <= 12.0 THEN
        region := 5;  t := x - 4.0;            cp := p4;  cq := q4;
        den := -1.0;   -- this region seeds the denominator with -1
    ELSE
        region := 6;
    END IF;

    -- Horner evaluation of numerator and denominator.
    IF region BETWEEN 1 AND 5 THEN
        FOR k IN 1..8 LOOP
            num := num * t + cp[k];
            den := den * t + cq[k];
        END LOOP;
    END IF;

    IF region = 0 THEN
        RETURN -1.0 * ln(x);
    ELSIF region = 1 THEN
        RETURN -ln(x) + (t * (d1 + t * (num/den)));
    ELSIF region = 2 THEN
        RETURN -ln(x) + (t * (d2 + t * (num/den)));
    ELSIF region = 3 THEN
        RETURN t * (d1 + t * (num/den));
    ELSIF region = 4 THEN
        RETURN t * (d2 + t * (num/den));
    ELSIF region = 5 THEN
        RETURN d4 + t * (num/den);
    END IF;

    -- x > 12: asymptotic series
    r   := c_start;
    lnx := ln(x);
    xx  := x*x;
    FOR k IN 1..6 LOOP
        r := r / xx + c[k];
    END LOOP;
    r := r / x;
    RETURN r + spi - 0.5 * lnx + x * (lnx-1);
END;
$$ LANGUAGE plpgsql;
| |
| -- Returns the Gamma probability density function |
| |
-- gampdf(X, A, B): gamma probability density with shape A and scale B,
-- evaluated at X.
--
-- The density is computed in log space:
--     (A - 1) * ln(X) - X/B - gammaln(A) - A * ln(B)
-- and exponentiated once. Log-densities below -690 are reported as 0.0
-- rather than passed to exp(). The log-density is evaluated a single time
-- and reused, so gammaln is called once per invocation.
--
-- X <= 0 or B <= 0 raises the error produced by ln(); NULL input yields NULL.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gampdf(
    X double precision,
    A double precision,
    B double precision
)
RETURNS double precision
AS $$
DECLARE
    log_density double precision;
BEGIN
    log_density := (A - 1) * ln(X) - (X/B) - MADLIB_SCHEMA.gammaln(A) - A * ln(B);
    IF (log_density < -690) THEN
        RETURN 0.0;
    END IF;
    RETURN exp(log_density);
END;
$$ LANGUAGE plpgsql;
| |
| -- Returns Chi-square probability density function |
| |
-- chi2pdf(X, V): chi-square probability density with V degrees of freedom,
-- computed as the gamma density with shape V/2 and scale 2.
-- For X <= 0 the function returns the fixed value .999999 instead of
-- evaluating the density.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.chi2pdf (
    X double precision,
    V double precision
) RETURNS double precision
AS $$
DECLARE
    shape double precision := V / 2.0;
BEGIN
    IF X <= 0 THEN
        RETURN .999999;
    ELSE
        RETURN MADLIB_SCHEMA.gampdf(X, shape, 2.0);
    END IF;
END;
$$ LANGUAGE plpgsql;
| |
-- remove_redundent(table_input, id_col_name, id_feature_name, class_col_name)
--
-- (Re)creates the temp tables weighted_points and weighted_points2 and fills
-- weighted_points with one row per (svec_hash(feature), class) group of the
-- source table; CreateHash supplies the representative point and its weight.
-- weighted_points2 is only created here, not populated.
--
--   table_input      source table or view
--   id_col_name      column holding the point id
--   id_feature_name  column holding the feature vector
--   class_col_name   column holding the class
--
-- NOTE(review): the names are spliced into the statement verbatim, so they
-- must come from a trusted caller.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.remove_redundent(table_input TEXT, id_col_name TEXT, id_feature_name TEXT, class_col_name TEXT) RETURNS void AS $$
DECLARE
    dedup_sql TEXT;
BEGIN
    DROP TABLE IF EXISTS weighted_points CASCADE;
    CREATE TEMP TABLE weighted_points (
        id        INT,
        feature   MADLIB_SCHEMA.svec,
        class     INT,
        weight    INT,
        selection INT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (selection)');

    DROP TABLE IF EXISTS weighted_points2 CASCADE;
    CREATE TEMP TABLE weighted_points2 (
        id        INT,
        feature   MADLIB_SCHEMA.svec,
        class     INT,
        weight    INT,
        selection INT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (selection)');

    dedup_sql :=
        'INSERT INTO weighted_points SELECT (MADLIB_SCHEMA.CreateHash(id, feature, class)).* FROM (SELECT '
        || id_col_name || ' as id, '
        || id_feature_name || ' as feature, '
        || class_col_name || ' as class, MADLIB_SCHEMA.svec_hash('
        || id_feature_name || ') as hash FROM '
        || table_input || ') as A GROUP BY A.hash,A.class';
    EXECUTE dedup_sql;
END
$$ LANGUAGE plpgsql;
| |
--------------------- FindInfoGain ---------- START
| |
| -- FindInfoGain must return Info Gain, Gain Significance and Probability of main class |
| |
-- Transition state of the FindInfoGain aggregate.
DROP TYPE IF EXISTS MADLIB_SCHEMA.findinfogain_type CASCADE;
CREATE TYPE MADLIB_SCHEMA.findinfogain_type AS (
    value      DOUBLE PRECISION[],  -- accumulator maintained by aggr_InfoGain
    num_class  INTEGER,             -- number of classes given to the aggregate
    num_values INTEGER,             -- number of feature values given to the aggregate
    dim        INTEGER              -- feature dimension being evaluated
);
| |
-- Result row of the FindInfoGain aggregate. find_best_split expands it
-- positionally into selectedDimentionTableResults, which gives the fields
-- the meanings noted below.
DROP TYPE IF EXISTS MADLIB_SCHEMA.findinfogain_result CASCADE;
CREATE TYPE MADLIB_SCHEMA.findinfogain_result AS (
    value1 DOUBLE PRECISION,  -- dim
    value2 DOUBLE PRECISION,  -- infoGain     (compute_InfoGain element 1)
    value3 DOUBLE PRECISION,  -- gainSign     (compute_InfoGain element 2)
    value4 DOUBLE PRECISION,  -- classProb    (compute_InfoGain element 3)
    value5 DOUBLE PRECISION,  -- classID      (compute_InfoGain element 4)
    value6 DOUBLE PRECISION   -- relativeSize (first element of the accumulator)
);
| |
-- findinfogain_sfunc: transition function of FindInfoGain.
--   $1  running state
--   $2  feature value of the row     $3  weight of the row
--   $4  number of classes            $5  number of feature values
--   $6  class of the row             $7  feature dimension
-- While the state has no class count yet (or one below 2) it is (re)seeded
-- with a zero-filled array of ($4+1)*($5+1) entries; the current row is then
-- folded in through aggr_InfoGain.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.findinfogain_sfunc(MADLIB_SCHEMA.findinfogain_type, FLOAT, FLOAT, INT, INT, INT, INT) CASCADE;
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.findinfogain_sfunc(MADLIB_SCHEMA.findinfogain_type, FLOAT, FLOAT, INT, INT, INT, INT) RETURNS MADLIB_SCHEMA.findinfogain_type AS $$
DECLARE
    st MADLIB_SCHEMA.findinfogain_type;
BEGIN
    st := $1;

    IF st.num_class IS NULL OR st.num_class < 2 THEN
        st.value      := MADLIB_SCHEMA.mallocset(($4+1)*($5+1),0);
        st.num_class  := $4;
        st.num_values := $5;
        st.dim        := $7;
    END IF;

    st.value := MADLIB_SCHEMA.aggr_InfoGain(st.value, $2, $3, $4, $5, $6);
    RETURN st;
END
$$ LANGUAGE plpgsql;
| |
-- findinfogain_prefunc: merges two partial FindInfoGain states.
-- A state counts as populated when its num_class is set. If only one side is
-- populated that side is returned as is; if both are, their accumulator
-- arrays are combined with array_add and the rest of the first state is kept.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.findinfogain_prefunc(MADLIB_SCHEMA.findinfogain_type, MADLIB_SCHEMA.findinfogain_type) CASCADE;
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.findinfogain_prefunc(MADLIB_SCHEMA.findinfogain_type, MADLIB_SCHEMA.findinfogain_type) RETURNS MADLIB_SCHEMA.findinfogain_type AS $$
DECLARE
    merged MADLIB_SCHEMA.findinfogain_type;
BEGIN
    IF $1.num_class IS NULL THEN
        RETURN $2;
    END IF;

    IF $2.num_class IS NULL THEN
        RETURN $1;
    END IF;

    merged := $1;
    merged.value := MADLIB_SCHEMA.array_add($1.value, $2.value);
    RETURN merged;
END
$$ LANGUAGE plpgsql;
| |
-- findinfogain_finalfunc: converts the final FindInfoGain state into a
-- findinfogain_result row.
--   value1      feature dimension (0 when unset)
--   value2..5   elements 1..4 of compute_InfoGain over the accumulator
--   value6      first element of the accumulator
-- When no accumulator exists the fixed row (dim, 0, 1, 1, 0, 0) is returned.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.findinfogain_finalfunc(MADLIB_SCHEMA.findinfogain_type) CASCADE;
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.findinfogain_finalfunc(MADLIB_SCHEMA.findinfogain_type) RETURNS MADLIB_SCHEMA.findinfogain_result AS $$
DECLARE
    gain_stats FLOAT[];
    out_row    MADLIB_SCHEMA.findinfogain_result;
BEGIN
    out_row.value1 := COALESCE($1.dim, 0);

    IF $1.value IS NULL THEN
        out_row.value2 := 0;
        out_row.value3 := 1;
        out_row.value4 := 1;
        out_row.value5 := 0;
        out_row.value6 := 0;
    ELSE
        gain_stats := MADLIB_SCHEMA.compute_InfoGain($1.value, $1.num_class, $1.num_values);
        out_row.value2 := gain_stats[1];
        out_row.value3 := gain_stats[2];
        out_row.value4 := gain_stats[3];
        out_row.value5 := gain_stats[4];
        out_row.value6 := $1.value[1];
    END IF;

    RETURN out_row;
END
$$ LANGUAGE plpgsql;
| |
-- FindInfoGain(value, weight, num_class, num_values, class, dim):
-- aggregates the rows of one feature dimension into a findinfogain_result.
-- The merge function is only registered when building for Greenplum.
DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.FindInfoGain(FLOAT, FLOAT, INT, INT, INT, INT);
CREATE AGGREGATE MADLIB_SCHEMA.FindInfoGain(FLOAT, FLOAT, INT, INT, INT, INT) (
    STYPE = MADLIB_SCHEMA.findinfogain_type,
    SFUNC = MADLIB_SCHEMA.findinfogain_sfunc,
    m4_ifdef(`GREENPLUM',`PREFUNC = MADLIB_SCHEMA.findinfogain_prefunc,')
    FINALFUNC = MADLIB_SCHEMA.findinfogain_finalfunc
);
| |
---------------------- FindInfoGain ---------- END
-- Return type of find_best_split: the chosen split and its statistics.
DROP TYPE IF EXISTS MADLIB_SCHEMA.res CASCADE;
CREATE TYPE MADLIB_SCHEMA.res AS (
    feature     INTEGER,           -- feature dimension selected for the split
    probability DOUBLE PRECISION,  -- class probability reported for that dimension
    maxclass    INTEGER,           -- class id reported for that dimension
    infogain    DOUBLE PRECISION,  -- information gain of the split
    live        INTEGER,           -- 1 or 0, set by the test at the end of find_best_split
    chisq       DOUBLE PRECISION   -- chi2pdf of the gain significance
);
| |
-- find_best_split: evaluates a sample of feature dimensions on the points of
-- one tree node and returns the dimension with the highest information gain.
--
--   feature_dimentions  length of the feature vectors
--   distinct_classes    number of classes, passed on to FindInfoGain
--   distinct_features   number of feature values, passed on to FindInfoGain
--   selection           value of the selection column identifying the node
--   sample_limit        maximum number of points to use; 0 means all of them
--   table_name          table holding the weighted points
--
-- Returns a MADLIB_SCHEMA.res row. All fields except live are NULL when no
-- candidate dimension qualifies; live is then 0.
--
-- Side effects: (re)creates the temp tables selectedDimentionTable and
-- selectedDimentionTableResults.
--
-- NOTE(review): table_name is spliced into the statements verbatim, so it
-- must come from a trusted caller.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.find_best_split(INT, INT, INT, INT, INT, TEXT);
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.find_best_split(feature_dimentions INT, distinct_classes INT, distinct_features INT, selection INT, sample_limit INT, table_name TEXT) RETURNS MADLIB_SCHEMA.res AS $$
declare
    sample_dimentions INT;
    selected_dimentions INT[];
    new_sample_limit INT := sample_limit;
    -- best row of selectedDimentionTableResults:
    -- [1] dim, [2] infoGain, [3] gainSign, [4] classProb, [5] classID, [6] relativeSize
    pre_result FLOAT[];
    result MADLIB_SCHEMA.res;
    total_size INT;
begin
    -- Number of dimensions that need to be sampled to find one that is in
    -- the 90th percentile with .999 probability, capped at the number of
    -- dimensions available.
    sample_dimentions = MADLIB_SCHEMA.min(floor(-ln(1-(.999)^(1/CAST(feature_dimentions AS FLOAT)))*feature_dimentions),feature_dimentions);
    selected_dimentions = MADLIB_SCHEMA.WeightedNoReplacement(sample_dimentions, feature_dimentions);

    -- Size of the node; also the sample size when no limit was requested.
    EXECUTE 'SELECT count(*) FROM ' || table_name || ' WHERE selection = ' || selection || ';' INTO total_size;
    IF (new_sample_limit = 0) THEN
        new_sample_limit = total_size;
    END IF;

    DROP TABLE IF EXISTS selectedDimentionTableResults;
    CREATE TEMP TABLE selectedDimentionTableResults(
        dim INT,
        infoGain FLOAT,
        gainSign FLOAT,
        classProb FLOAT,
        classID FLOAT,
        relativeSize FLOAT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (dim)');

    DROP TABLE IF EXISTS selectedDimentionTable;
    CREATE TEMP TABLE selectedDimentionTable(
        dim INT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (dim)');

    -- Candidate dimensions, de-duplicated.
    EXECUTE 'INSERT INTO selectedDimentionTable SELECT distinct unnest(ARRAY[ ' || array_to_string(selected_dimentions, ',') || ']);';

    -- One FindInfoGain result per candidate dimension, computed over the
    -- sampled points whose projection onto that dimension is positive.
    EXECUTE 'INSERT INTO selectedDimentionTableResults SELECT (g.t).* FROM (SELECT MADLIB_SCHEMA.FindInfoGain(MADLIB_SCHEMA.svec_proj(wp.feature, dt.dim), wp.weight, '||distinct_classes||
        ', '||distinct_features||', wp.class, dt.dim) AS t FROM (SELECT w.feature, w.weight, w.class FROM ' ||
        table_name || ' w WHERE w.selection = ' || selection || ' LIMIT ' || new_sample_limit || ') AS wp CROSS JOIN (SELECT * FROM ' ||
        ' selectedDimentionTable) AS dt WHERE MADLIB_SCHEMA.svec_proj(wp.feature, dt.dim) > 0 GROUP BY dt.dim) AS g;';

    -- Row with the maximum information gain, provided its class id is positive.
    EXECUTE 'SELECT ARRAY[dt.dim, dt.infoGain, dt.gainSign, dt.classProb, dt.classID, dt.relativeSize] FROM selectedDimentionTableResults dt, (SELECT max(infoGain) AS maxClass FROM selectedDimentionTableResults) AS m WHERE dt.infoGain = m.maxClass AND dt.classID > 0 LIMIT 1' INTO pre_result;

    result.feature = pre_result[1];
    result.maxclass = pre_result[5];
    result.probability = pre_result[4];
    result.infogain = pre_result[2];
    result.chisq = MADLIB_SCHEMA.chi2pdf(pre_result[3], distinct_features-1);

    -- The node stays live when the gain exceeds .1/sample_dimentions and
    -- either chi-square test falls below its threshold. A NULL condition
    -- (no candidate found) leaves the node not live.
    IF (((result.chisq < 0.5/sample_dimentions) OR (MADLIB_SCHEMA.chi2pdf((pre_result[6]-total_size)*(pre_result[6]-total_size)/total_size, 1) < .1))AND(result.infogain > .1/sample_dimentions)) THEN
        result.live = 1;
    ELSE
        result.live = 0;
    END IF;

    RETURN result;
end
$$ language plpgsql;
| |
| --------------------- JumpCalc ---------- START |
| |
-- jump_sfunc(state, pos, target): transition function of JumpCalc.
-- Stores target at array index pos + 1 of the state and returns the
-- updated array.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.jump_sfunc(INT[], INT, INT) CASCADE;
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.jump_sfunc(INT[], INT, INT) RETURNS INT[] AS $$
DECLARE
    slots INT[] := $1;
BEGIN
    slots[$2 + 1] := $3;
    RETURN slots;
END
$$ LANGUAGE plpgsql;
| |
-- JumpCalc(pos, target): builds an integer array in which element pos + 1
-- holds target, for every aggregated (pos, target) pair.
-- Cleanup_Tree uses it to build the jump array of each parent node.
DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.JumpCalc(INT, INT);
CREATE AGGREGATE MADLIB_SCHEMA.JumpCalc(INT, INT) (
    STYPE = INT[],
    SFUNC = MADLIB_SCHEMA.jump_sfunc
);
| |
| ---------------------- JumpCalc ---------- END |
| |
-- declare_tree_tables(): drops and recreates the tables used while building
-- a tree:
--   tree2               temp working copy with the same columns as the model
--   MADLIB_SCHEMA.tree  persistent model table
--   finaltree,
--   finaltree2          temp id-mapping tables (id, new_id, parent_id)
-- Any existing content of these tables, including a previously trained
-- model, is lost.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.declare_tree_tables();
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.declare_tree_tables() RETURNS void AS $$
BEGIN
    DROP TABLE IF EXISTS tree2 CASCADE;
    CREATE TEMP TABLE tree2 (
        id            SERIAL,
        tree_location INTEGER[],
        hash          INTEGER,
        feature       INTEGER,
        probability   DOUBLE PRECISION,
        chisq         DOUBLE PRECISION,
        maxclass      INTEGER,
        infogain      DOUBLE PRECISION,
        live          INTEGER,
        cat_size      INTEGER,
        parent_id     INTEGER,
        jump          INTEGER[]
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (id)');

    DROP TABLE IF EXISTS MADLIB_SCHEMA.tree CASCADE;
    CREATE TABLE MADLIB_SCHEMA.tree (
        id            SERIAL,
        tree_location INTEGER[],
        hash          INTEGER,
        feature       INTEGER,
        probability   DOUBLE PRECISION,
        chisq         DOUBLE PRECISION,
        maxclass      INTEGER,
        infogain      DOUBLE PRECISION,
        live          INTEGER,
        cat_size      INTEGER,
        parent_id     INTEGER,
        jump          INTEGER[]
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (id)');

    DROP TABLE IF EXISTS finaltree CASCADE;
    CREATE TEMP TABLE finaltree (
        id        INTEGER,
        new_id    INTEGER,
        parent_id INTEGER
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (id)');

    DROP TABLE IF EXISTS finaltree2 CASCADE;
    CREATE TEMP TABLE finaltree2 (
        id        INTEGER,
        new_id    INTEGER,
        parent_id INTEGER
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (id)');
END
$$ LANGUAGE plpgsql;
| |
/**
 * @brief Classify data points using the trained decision tree model
 *
 * Points are pushed through MADLIB_SCHEMA.tree one node at a time, in
 * ascending node id order. Two temp tables act as source and destination and
 * swap roles whenever the depth of the current node increases; points whose
 * jump value has become 0 are moved to the result table.
 *
 * @param table_name Name of the table/view with the source data
 * @param id_col_name Name of the column containing id of each point
 * @param feature_col_name Name of the column containing feature vector of each point
 * @param num_values Accepted for interface compatibility; not referenced by
 *        the function body
 * @param verbosity If greater than 0, a progress line is raised at INFO level
 *        for every node visited
 *
 * @return Nothing. The results are written to table
 *         MADLIB_SCHEMA.classified_points, which is dropped and recreated:
 *  - <tt>id INT</tt> - Point id.
 *  - <tt>feature MADLIB_SCHEMA.SVEC</tt> - Point feature vector.
 *  - <tt>jump INT</tt> - Id of the next tree node to visit; 0 once the point
 *    has been classified.
 *  - <tt>class INT</tt> - Class prediction.
 *  - <tt>prob FLOAT</tt> - Probability of the predicted class.
 *
 * @note This function starts an iterative algorithm. It is not an aggregate
 *       function. Source and column names have to be passed as strings and
 *       are spliced into the generated statements verbatim.
 */
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.Classify_Tree(TEXT, TEXT, TEXT, INT, INT);
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.Classify_Tree(table_name TEXT, id_col_name TEXT, feature_col_name TEXT, num_values INT, verbosity INT) RETURNS void AS $$
DECLARE
    table_names         TEXT[] := '{classified_points1,classified_points2}';
    jump_so_far         INT := 0;   -- id of the tree node being processed
    old_jump_so_far     INT := 0;
    table_pick          INT := 1;   -- index of the current destination table
    old_depth           INT := 0;
    new_depth           INT := 0;
    remains_to_classify INT;
    size_finished       INT;
    other_table         TEXT;       -- the table that is not the destination
BEGIN
    DROP TABLE IF EXISTS classified_points1;
    CREATE TEMP TABLE classified_points1 (
        id      INT,
        feature MADLIB_SCHEMA.svec,
        jump    INT,
        class   INT,
        prob    FLOAT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (jump)');

    DROP TABLE IF EXISTS classified_points2;
    CREATE TEMP TABLE classified_points2 (
        id      INT,
        feature MADLIB_SCHEMA.svec,
        jump    INT,
        class   INT,
        prob    FLOAT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (jump)');

    DROP TABLE IF EXISTS MADLIB_SCHEMA.classified_points;
    CREATE TABLE MADLIB_SCHEMA.classified_points (
        id      INT,
        feature MADLIB_SCHEMA.svec,
        jump    INT,
        class   INT,
        prob    FLOAT
    ) m4_ifdef(`GREENPLUM',`DISTRIBUTED BY (jump)');

    -- Every point starts at node 1 with no prediction yet.
    EXECUTE 'INSERT INTO classified_points1 (id, feature, jump, class, prob) SELECT '||id_col_name||', '||feature_col_name||', 1, 0, 0 FROM ' || table_name || ';';

    LOOP
        -- Advance to the next node id; NULL once the tree is exhausted.
        EXECUTE 'SELECT id FROM MADLIB_SCHEMA.tree WHERE id > '|| jump_so_far ||' ORDER BY id LIMIT 1' INTO jump_so_far;
        IF (verbosity > 0) THEN
            RAISE INFO 'CLASSIFICATION STEP [%] CLASSIFIED: % TO BE CLASSIFIED: %', jump_so_far, size_finished, remains_to_classify;
        END IF;

        new_depth := NULL;
        IF (jump_so_far IS NOT NULL) THEN
            EXECUTE 'SELECT array_upper(tree_location, 1) FROM MADLIB_SCHEMA.tree WHERE id = ' ||jump_so_far|| ';' INTO new_depth;
        END IF;

        -- One level deeper: flush the finished points of the other table,
        -- empty it and make it the new destination.
        IF (new_depth > old_depth) THEN
            other_table := table_names[(table_pick)%2+1];
            EXECUTE 'INSERT INTO MADLIB_SCHEMA.classified_points SELECT * FROM '|| other_table ||' WHERE jump = 0;';
            EXECUTE 'TRUNCATE '|| other_table ||';';
            EXECUTE 'SELECT count(*) FROM MADLIB_SCHEMA.classified_points;' INTO size_finished;
            table_pick := table_pick%2+1;
        END IF;
        old_depth := new_depth;

        other_table := table_names[(table_pick)%2+1];
        EXECUTE 'SELECT count(*) FROM '|| other_table ||';' INTO remains_to_classify;
        EXIT WHEN (jump_so_far IS NULL) OR (jump_so_far = old_jump_so_far) OR (remains_to_classify = 0);
        old_jump_so_far := jump_so_far;

        -- Move the points waiting at this node into the destination table,
        -- following the node's jump entry for their feature value.
        EXECUTE 'INSERT INTO '|| table_names[table_pick] ||' SELECT pt.id, pt.feature, COALESCE(gt.jump[MADLIB_SCHEMA.svec_proj(pt.feature, gt.feature)+1],0), gt.maxclass, gt.probability FROM (SELECT * FROM '||
            other_table ||' WHERE jump = '|| jump_so_far ||') AS pt, (SELECT * FROM MADLIB_SCHEMA.tree WHERE id = '|| jump_so_far||') AS gt;';
    END LOOP;

    -- Collect the finished points still held in either working table.
    EXECUTE 'INSERT INTO MADLIB_SCHEMA.classified_points SELECT * FROM '|| table_names[table_pick] ||' WHERE jump = 0;';
    EXECUTE 'INSERT INTO MADLIB_SCHEMA.classified_points SELECT * FROM '|| table_names[table_pick%2+1] ||' WHERE jump = 0;';
END
$$ LANGUAGE plpgsql;
| |
-- Classify_Tree(table_name, id_col_name, feature_col_name, num_values):
-- convenience overload that runs the five-argument Classify_Tree with
-- verbosity 0.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.Classify_Tree(TEXT, TEXT, TEXT, INT);
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.Classify_Tree(table_name TEXT, id_col_name TEXT, feature_col_name TEXT, num_values INT) RETURNS void AS $$
BEGIN
    PERFORM MADLIB_SCHEMA.Classify_Tree(
        table_name,
        id_col_name,
        feature_col_name,
        num_values,
        0
    );
    RETURN;
END
$$ LANGUAGE plpgsql;
| |
-- Cleanup_Tree(num_values): turns the working table tree2 into the final
-- model in MADLIB_SCHEMA.tree.
--   1. Nodes with no points (cat_size NULL or 0) are deleted from tree2.
--   2. The remaining nodes are renumbered 1..n in ascending id order and the
--      parent references are translated to the new ids.
--   3. MADLIB_SCHEMA.tree is emptied and refilled with the renumbered nodes
--      plus the row of tree2 whose id is 1.
--   4. The jump array of every parent is rebuilt with JumpCalc from the last
--      tree_location element and the id of each of its children.
-- tree2, finaltree and finaltree2 are left empty. The num_values argument is
-- not referenced by the function body.
DROP FUNCTION IF EXISTS MADLIB_SCHEMA.Cleanup_Tree(INT);
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.Cleanup_Tree(num_values INT) RETURNS void AS $$
DECLARE
    tree_size    INTEGER;
    renumber_sql TEXT;
BEGIN
    TRUNCATE finaltree;
    TRUNCATE finaltree2;

    EXECUTE 'DELETE FROM tree2 WHERE COALESCE(cat_size,0) = 0';

    -- Old id -> new id. The running count over descending ids, subtracted
    -- from tree_size + 1, numbers the nodes 1..tree_size in ascending order.
    EXECUTE 'SELECT count(*) FROM tree2' INTO tree_size;
    renumber_sql :=
        'INSERT INTO finaltree (id, parent_id, new_id) SELECT id, MAX(parent_id), ('
        || tree_size
        || '+1) - count(1) OVER(ORDER BY id DESC ROWS UNBOUNDED PRECEDING) FROM tree2 GROUP BY id';
    EXECUTE renumber_sql;

    -- Same mapping with parent_id replaced by the new id of the parent.
    -- Nodes whose parent is not in the mapping get no row here.
    EXECUTE 'INSERT INTO finaltree2 (id, parent_id, new_id) SELECT g2.id,g.new_id,g2.new_id FROM finaltree g, finaltree g2 WHERE g.id = g2.parent_id';

    TRUNCATE finaltree;
    TRUNCATE MADLIB_SCHEMA.tree;

    EXECUTE 'INSERT INTO MADLIB_SCHEMA.tree SELECT n.new_id, g.tree_location, g.hash, g.feature, g.probability, g.chisq, g.maxclass, g.infogain, g.live, g.cat_size, n.parent_id, g.jump FROM tree2 g, finaltree2 n WHERE n.id = g.id';
    EXECUTE 'INSERT INTO MADLIB_SCHEMA.tree SELECT * FROM tree2 WHERE id = 1';

    -- tree2 is reused as scratch space for the per-parent jump arrays.
    TRUNCATE tree2;
    EXECUTE 'INSERT INTO tree2 (id, jump) SELECT parent_id, MADLIB_SCHEMA.JumpCalc(tree_location[array_upper(tree_location,1)], id) FROM MADLIB_SCHEMA.tree GROUP BY parent_id';
    TRUNCATE finaltree2;
    EXECUTE 'UPDATE MADLIB_SCHEMA.tree k SET jump = g.jump FROM tree2 g WHERE g.id = k.id';
    TRUNCATE tree2;
END
$$ LANGUAGE plpgsql;
| |
| /** |
| * @brief Train a decision tree model |
| * |
| * @param table_input Name of the table/view with the source data |
| * @param id_col_name Name of the column containing id of each point |
| * @param feature_col_name Name of the column containing feature vector of each point |
| * @param class_col_name Name of the column containing correct class of each point |
| * @param num_values Largest value a feature can take; a split creates one child branch per value in 0..num_values |
| * @param max_num_iter Max number of branches to follow (e.g. 2000) |
| * @param verbosity If greater than 0, progress messages are emitted (verbose mode) |
| * |
| * @return Table MADLIB_SCHEMA.tree: |
| * - <tt>id SERIAL</tt> - Tree node id |
| * - <tt>tree_location INT[]</tt> - Set of values that lead to this branch. |
| * 0 is the initial point (no value). But this path does not specify which |
| * feature was used for the branching. |
| * - <tt>hash INT</tt>: Hash value of the path. For quick unique identification. |
| * - <tt>feature INT</tt>: Which element of the feature vector was used for |
| * branching at this node. Notice that this feature is not used in the current |
| * <tt>tree_location</tt>, it will be added in the next step. |
| * - <tt>probability FLOAT</tt> - If forced to make a call for a dominant class |
| * at a given point this would be the confidence of the call (this is only an |
| * estimated value). |
| * - <tt>chisq FLOAT</tt> - Chi-square value of the branching significance, |
| * used to determine termination of the branch. |
| * - <tt>maxclass INTEGER</tt> - If forced to make a call for a dominant class |
| * at a given point this is the selected class. |
| * - <tt>infogain FLOAT</tt> - Information gain computed using entropy (at this |
| * node), also used to determine termination of the branch. |
| * - <tt>live INT</tt> - Indication that the branch is still growing. 1 means "live". |
| * Exit value may be 1 if number of branches reached before termination condition is met. |
| * - <tt>cat_size INT</tt> - Number of data points at this node. |
| * - <tt>parent_id INT</tt> - Id of the parent branch. |
| * - <tt>jump INT[]</tt> - Location of children for each feature value. Notice |
| * that assuming that the data is sparse we can have value 0 - representing |
| * no value. Result such as [2:3]={2,3}, should be read: |
| * jump['feature value'+1], so in this case there were no 0-value points for |
| * this feature. For value 1 jump to 2; for value 2 jump to 3; |
| * |
| * @note This function starts an iterative algorithm. It is not an aggregate |
| * function. Source and column names have to be passed as strings (due to |
| * limitations of the SQL syntax). |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.Train_Tree(table_input TEXT, id_col_name TEXT, feature_col_name TEXT, class_col_name TEXT, num_values INT, max_num_iter INT, verbosity INT) RETURNS void AS $$ |
| -- Implementation notes: |
| --  * remove_redundent() is run on the input first; the loop below works on the |
| --    resulting weighted_points table, not on table_input directly. |
| --  * tree2 holds the tree under construction. Each iteration handles the live |
| --    node with the lowest id. The loop ends when no live node is left or when |
| --    the iteration counter (started at max_num_iter) reaches 0. |
| --  * Points live in two tables, weighted_points and weighted_points2. Points are |
| --    read from one table and, for every node that is split, rewritten into the |
| --    other with a new selection value. The roles swap whenever a node with a |
| --    longer tree_location than any seen so far is selected. |
| --  * table_input is concatenated into dynamic SQL as given, so it must be a |
| --    trusted table or view name. |
| declare |
|     feature_dimension INT;        -- svec_dimension of the feature column of one weighted point |
|     lifenodes INT;                -- number of nodes in tree2 with live = 1 |
|     selection INT;                -- id of the node handled in this iteration (lowest live id) |
|     sample_limit INT := 0;        -- forwarded to find_best_split; always 0 here |
|     location INT[];               -- tree_location of the selected node |
|     temp_location INT[];          -- location extended by one feature value (child path) |
|     num_classes INT;              -- distinct classes among the weighted points |
|     max_iter INT = max_num_iter;  -- iterations still allowed |
|     answer MADLIB_SCHEMA.res;     -- result of find_best_split for the selected node |
|     location_size INT;            -- longest tree_location seen so far |
|     max_id INT;                   -- highest id among live nodes |
|     flip INT := 1;                -- position in table_names of the table being written |
|     category_size FLOAT[];        -- [1] row count, [2] sum of weight for the selected node |
|     category_class INT;           -- max(class) among the points of a leaf node |
|     table_names TEXT[] := '{weighted_points,weighted_points2}'; |
|     time_stamp2 TIMESTAMP;        -- start time, used for the total run time message |
|     misc_size INT;                -- row counts reported in verbose mode |
| begin |
| |
|     PERFORM MADLIB_SCHEMA.declare_tree_tables(); |
|     time_stamp2 = clock_timestamp(); |
|     EXECUTE 'SELECT count(*) FROM '|| table_input ||';' INTO misc_size; |
|     IF(verbosity > 0) THEN |
|         RAISE INFO 'INPUT TABLE SIZE: %', misc_size; |
|     END IF; |
| |
|     PERFORM MADLIB_SCHEMA.remove_redundent(table_input, id_col_name, feature_col_name, class_col_name); |
|     EXECUTE 'SELECT count(*) FROM weighted_points;' INTO misc_size; |
|     IF(verbosity > 0) THEN |
|         RAISE INFO 'TABLE SIZE AFTER COMPRESSION: %', misc_size; |
|     END IF; |
| |
|     EXECUTE 'SELECT MADLIB_SCHEMA.svec_dimension(feature) FROM ' || table_names[1] || ' LIMIT 1;' INTO feature_dimension; |
|     EXECUTE 'SELECT COUNT(DISTINCT class) FROM ' || table_names[1] || ';' INTO num_classes; |
|     IF(verbosity > 0) THEN |
|         RAISE INFO 'NUMBER OF CLASSES IN THE TRAINING SET %', num_classes; |
|     END IF; |
|     -- Fewer than two distinct classes is rejected; the message states the same bound as the check. |
|     IF(num_classes < 2) THEN |
|         RAISE EXCEPTION 'Number of classes cannot be less than 2'; |
|     END IF; |
| |
|     -- Seed tree2 with the root node: path {0}, live, no parent. |
|     EXECUTE 'INSERT INTO tree2 (tree_location, hash, feature, probability, chisq, maxclass, infogain, live, cat_size, parent_id) VALUES(ARRAY[0], MADLIB_SCHEMA.hash_array(ARRAY[0]), 0, 1, 1, 1, 1, 1, 0, 0)'; |
|     location_size = 0; |
| |
|     LOOP |
|         EXECUTE 'SELECT COUNT(*) FROM tree2 WHERE live = 1' INTO lifenodes; |
|         IF((max_iter = 0) OR (lifenodes < 1)) THEN |
|             IF(verbosity > 0) THEN |
|                 RAISE INFO 'EXIT: LIMIT % OR NO NODES LEFT', max_iter; |
|             END IF; |
|             EXIT; |
|         END IF; |
|         max_iter = max_iter-1; |
| |
|         -- Pick the live node with the lowest id; remember the highest live id too. |
|         EXECUTE 'SELECT id FROM tree2 WHERE live = 1 ORDER BY id LIMIT 1' INTO selection; |
|         EXECUTE 'SELECT id FROM tree2 WHERE live = 1 ORDER BY id DESC LIMIT 1' INTO max_id; |
|         EXECUTE 'SELECT gt.tree_location FROM tree2 gt WHERE gt.id =' || selection ||';' INTO location; |
| |
|         -- A longer path than any seen before: swap the roles of the two point |
|         -- tables and empty the one that will be written from here on. |
|         IF(location_size < array_upper(location,1)) THEN |
|             flip = (flip)%2+1; |
|             location_size = array_upper(location,1); |
|             EXECUTE 'TRUNCATE TABLE ' || table_names[flip] || ';'; |
|         END IF; |
| |
|         -- table_names[(flip%2)+1] is the table currently being read. |
|         EXECUTE 'SELECT ARRAY[COALESCE(count(*),0),COALESCE(sum(weight),0)] FROM ' || table_names[(flip%2)+1] || ' WHERE selection = ' || selection || ';' INTO category_size; |
|         IF ((category_size[1] > 1) AND (category_size[2] > num_classes)) THEN |
|             -- More than one distinct point and enough total weight: try to split. |
|             IF(verbosity > 0) THEN |
|                 RAISE INFO 'CURRENT SELECTION % CATEGORY SIZE %', selection, category_size[2]; |
|             END IF; |
|             answer = MADLIB_SCHEMA.find_best_split(feature_dimension, num_classes, num_values, selection, sample_limit, table_names[(flip)%2+1]); |
| |
|             -- Record the split result on the node and retire it from the live set. |
|             EXECUTE 'UPDATE tree2 SET feature = ' || answer.feature |
|                  || ', probability = ' || answer.probability |
|                  || ', maxclass = ' || answer.maxclass |
|                  || ', infogain = ' || answer.infogain |
|                  || ', cat_size = ' || category_size[2] |
|                  || ', live = 0' |
|                  || ', chisq = ' || answer.chisq |
|                  || ' WHERE id = ' || selection || ';'; |
| |
|             IF (answer.live > 0) THEN |
|                 -- One live child per feature value 0..num_values. |
|                 FOR i IN 0..num_values LOOP |
|                     temp_location = location; |
|                     temp_location[array_upper(temp_location,1)+1] = i; |
|                     EXECUTE 'INSERT INTO tree2 (tree_location, hash, feature, probability, maxclass, infogain, live, parent_id) VALUES(ARRAY['||array_to_string(temp_location, ',')||'], ' || MADLIB_SCHEMA.hash_array(temp_location) || ', 0, 1, 1, 1, 1, '|| selection ||');'; |
|                 END LOOP; |
|                 -- Move the points of this node to the other table; their new selection |
|                 -- is the value of the chosen feature plus max_id + 1. |
|                 -- NOTE(review): presumably this equals the id of the matching child, |
|                 -- which relies on the children getting consecutive ids after max_id; confirm. |
|                 EXECUTE 'INSERT INTO ' || table_names[flip] |
|                      || ' SELECT id, feature, class, weight, MADLIB_SCHEMA.svec_proj(feature,' || answer.feature || ') + ' || (max_id + 1) |
|                      || ' FROM ' || table_names[(flip%2)+1] |
|                      || ' WHERE selection = ' || selection || ';'; |
|             END IF; |
|         ELSE |
|             -- Not splittable: retire the node. |
|             EXECUTE 'UPDATE tree2 SET live = 0 WHERE id = '||selection||';'; |
|             IF (category_size[2] > num_classes) THEN |
|                 EXECUTE 'SELECT max(class) FROM ' || table_names[(flip%2)+1] || ' WHERE selection = ' || selection || ';' INTO category_class; |
| |
|                 -- Extra node at the same path with this node as parent. The hash is |
|                 -- computed from location, the same path stored in tree_location; |
|                 -- temp_location is only assigned while children of a split node are |
|                 -- created, so in this branch it is NULL or left over from another node. |
|                 -- NOTE(review): the row is live and has no cat_size, so it looks like |
|                 -- a later iteration or Cleanup_Tree removes it again; confirm intent. |
|                 EXECUTE 'INSERT INTO tree2 (tree_location, hash, feature, probability, maxclass, infogain, live, parent_id) VALUES(ARRAY[' || array_to_string(location, ',') || '], ' |
|                      || 'MADLIB_SCHEMA.hash_array(ARRAY[' || array_to_string(location, ',') || ']), 0, 1, 1, 1, 1,' || selection || ')'; |
| |
|                 -- Turn the selected node into a leaf predicting category_class. |
|                 EXECUTE 'UPDATE tree2 SET feature = 1' |
|                      || ', probability = 1.0' |
|                      || ', maxclass = ' || category_class |
|                      || ', infogain = 0' |
|                      || ', cat_size = ' || category_size[2] |
|                      || ' WHERE id = ' || selection || ';'; |
|             ELSE |
|                 -- Total weight does not exceed the number of classes: drop the node. |
|                 EXECUTE 'DELETE FROM tree2 WHERE id = '||selection||';'; |
|             END IF; |
|         END IF; |
|     END LOOP; |
| |
|     -- Renumber the nodes and publish the result to MADLIB_SCHEMA.tree. |
|     EXECUTE 'SELECT MADLIB_SCHEMA.Cleanup_Tree(' || num_values || ');'; |
|     IF(verbosity > 0) THEN |
|         RAISE INFO '-------> FINAL TIME %' , (clock_timestamp() - time_stamp2); |
|     END IF; |
| end |
| $$ language plpgsql; |
| |
| -- Convenience overload of Train_Tree without the verbosity argument: forwards |
| -- all six arguments to the seven-argument variant with verbosity = 0. |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.Train_Tree( |
|     table_input TEXT, |
|     id_col_name TEXT, |
|     feature_col_name TEXT, |
|     class_col_name TEXT, |
|     num_values INT, |
|     max_num_iter INT |
| ) RETURNS void AS $$ |
| BEGIN |
|     PERFORM MADLIB_SCHEMA.Train_Tree(table_input, id_col_name, feature_col_name, class_col_name, num_values, max_num_iter, 0); |
| END; |
| $$ LANGUAGE plpgsql; |