Decision Tree: allow flexible id column name
JIRA: MADLIB-847
diff --git a/methods/cart/src/pg_gp/dt.sql_in b/methods/cart/src/pg_gp/dt.sql_in
index c74cc30..ba54eae 100644
--- a/methods/cart/src/pg_gp/dt.sql_in
+++ b/methods/cart/src/pg_gp/dt.sql_in
@@ -1079,7 +1079,6 @@
encoded_table_name TEXT;
metatable_name TEXT;
curstmt TEXT;
- id_col_name TEXT;
class_col_name TEXT;
classify_result TEXT;
temp_text TEXT;
@@ -1088,7 +1087,6 @@
swap_tree_table TEXT;
BEGIN
metatable_name = tree_table_name || '_di';
- id_col_name = MADLIB_SCHEMA.__get_id_column_name(metatable_name);
class_col_name = MADLIB_SCHEMA.__get_class_column_name(metatable_name);
-- the value of class column in validation table must in the KV table
@@ -1160,7 +1158,7 @@
%
) AS g
FROM % c, % s
- WHERE c.id=s.%
+ WHERE c.id=s.id
GROUP BY parent_id
) t
WHERE t.g[2] >= 0 AND
@@ -1179,7 +1177,6 @@
MADLIB_SCHEMA.__to_char(max_num_classes),
classify_result,
encoded_table_name,
- id_col_name,
tree_table_name,
tree_table_name
]
diff --git a/methods/cart/src/pg_gp/dt_preproc.sql_in b/methods/cart/src/pg_gp/dt_preproc.sql_in
index 8fb69fc..ad30d1c 100644
--- a/methods/cart/src/pg_gp/dt_preproc.sql_in
+++ b/methods/cart/src/pg_gp/dt_preproc.sql_in
@@ -1107,7 +1107,7 @@
curstmt = MADLIB_SCHEMA.__format
(
'INSERT INTO %(id, fid, fval, is_cont, class)
- SELECT id, fid, % as fval, is_cont, class
+ SELECT %, fid, % as fval, is_cont, class
FROM
(
SELECT %, generate_series(1, %) as fid,
@@ -1119,6 +1119,7 @@
%',
ARRAY[
breakup_tbl_name,
+ id_col_name,
fval_txt,
id_col_name,
array_upper(is_conts, 1)::TEXT,
@@ -1138,7 +1139,7 @@
curstmt = MADLIB_SCHEMA.__format
(
'INSERT INTO %(id, fid, fval, is_cont, class)
- SELECT id, fid, % as fval, is_cont, class
+ SELECT %, fid, % as fval, is_cont, class
FROM
(
SELECT %, generate_series(1, %) as fid,
@@ -1152,6 +1153,7 @@
%',
ARRAY[
breakup_tbl_name,
+ id_col_name,
fval_txt,
id_col_name,
array_upper(is_conts, 1)::TEXT,
diff --git a/methods/cart/src/pg_gp/sql/dt_test.sql_in b/methods/cart/src/pg_gp/sql/dt_test.sql_in
index 53e4951..bbdb8a0 100644
--- a/methods/cart/src/pg_gp/sql/dt_test.sql_in
+++ b/methods/cart/src/pg_gp/sql/dt_test.sql_in
@@ -2,13 +2,13 @@
DROP TABLE IF EXISTS golf_dt_test;
CREATE TABLE golf_dt_test (
- id INT,
+ xid INT, -- row id column, named xid rather than id; test calls later in this patch pass xid as the trailing id argument of c45_test
outlook TEXT,
temperature DOUBLE PRECISION,
humidity DOUBLE PRECISION,
windy TEXT,
class TEXT
-) m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)');
+) m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (xid)');
INSERT INTO golf_dt_test VALUES (1, 'sunny', 85, 85, ' false', ' Do not Play');
INSERT INTO golf_dt_test VALUES (3, 'overcast', 83, 78, ' false', ' Play');
@@ -537,7 +537,7 @@
INSERT INTO crx_dt_test VALUES (488, 'b', 24.5, 13.335000000000001, 'y', 'p', 'aa', 'v', 0.040000000000000001, 'f', 'f', 0, 't', 'g', 120, 475, '-');
INSERT INTO crx_dt_test VALUES (490, NULL, 45.329999999999998, 1, 'u', 'g', 'q', 'v', 0.125, 'f', 'f', 0, 't', 'g', 263, 0, '-');
-DROP FUNCTION IF EXISTS c45_test(TEXT,TEXT,TEXT,TEXT,TEXT,TEXT,INT,TEXT,TEXT);
+DROP FUNCTION IF EXISTS c45_test(TEXT,TEXT,TEXT,TEXT,TEXT,TEXT,INT,TEXT,TEXT,TEXT);
CREATE OR REPLACE FUNCTION c45_test
(
training_set TEXT,
@@ -548,7 +548,8 @@
validation_set TEXT,
tree_dep INT,
missing_operation TEXT,
- exp_display_result TEXT
+ exp_display_result TEXT,
+ id TEXT
)
RETURNS TEXT AS $$
declare
@@ -563,7 +564,7 @@
begin
train_result = c45_train(method,training_set,tree_name,
- validation_set,cont_features,null,'id',class_column, 100,
+ validation_set,cont_features,null,id,class_column, 100,
missing_operation,tree_dep,0,0,0);
-- ensure we didn't change the variable names of the returned type
EXECUTE 'SELECT COUNT(*) = 5 FROM pg_attribute
@@ -617,6 +618,22 @@
end
$$ language plpgsql;
+-- Nine-argument overload of c45_test: forwards $1..$9 to the ten-argument version and supplies the string id as the tenth (id) argument.
+DROP FUNCTION IF EXISTS c45_test(TEXT, TEXT, TEXT, TEXT, TEXT, TEXT, INT, TEXT, TEXT);
+CREATE OR REPLACE FUNCTION c45_test(
+    training_set       TEXT,
+    method             TEXT,
+    cont_features      TEXT,
+    class_column       TEXT,
+    score_set          TEXT,
+    validation_set     TEXT,
+    tree_dep           INT,
+    missing_operation  TEXT,
+    exp_display_result TEXT
+) RETURNS TEXT
+AS $$
+    SELECT c45_test($1, $2, $3, $4, $5, $6, $7, $8, $9, 'id');
+$$ LANGUAGE sql;
-- Golf dataset is very small. Since random forest uses sampling, we think this dataset
-- is not appropriate for random forest test. We only test c4.5 based on golf.
SELECT c45_test('golf_dt_test','infogain','temperature,humidity',
@@ -629,7 +646,7 @@
windy: = true : class( Do not Play) num_elements(2) predict_prob(1)
outlook: = sunny : class( Do not Play) num_elements(5) predict_prob(0.6)
humidity: <= 70 : class( Play) num_elements(2) predict_prob(1)
- humidity: > 70 : class( Do not Play) num_elements(3) predict_prob(1)');
+ humidity: > 70 : class( Do not Play) num_elements(3) predict_prob(1)', 'xid');
SELECT c45_test('golf_dt_test','gini','temperature,humidity',
@@ -642,7 +659,7 @@
windy: = true : class( Do not Play) num_elements(2) predict_prob(1)
outlook: = sunny : class( Do not Play) num_elements(5) predict_prob(0.6)
humidity: <= 70 : class( Play) num_elements(2) predict_prob(1)
- humidity: > 70 : class( Do not Play) num_elements(3) predict_prob(1)');
+ humidity: > 70 : class( Do not Play) num_elements(3) predict_prob(1)', 'xid');
-- Verify temporary table can be used as training table.
CREATE TABLE golf_tmp AS SELECT * from golf_dt_test;
@@ -666,7 +683,7 @@
temperature: > 72 : class( Play) num_elements(2) predict_prob(1)
temperature: > 75 : class( Do not Play) num_elements(1) predict_prob(1)
temperature: > 80 : class( Play) num_elements(2) predict_prob(1)
- temperature: > 83 : class( Do not Play) num_elements(1) predict_prob(1)');
+ temperature: > 83 : class( Do not Play) num_elements(1) predict_prob(1)', 'xid');
SELECT c45_test('crx_dt_test','gainratio','A2,A3,A8,A11,A14,A15',
'a16','crx_dt_test','crx_dt_test',10,'explicit',