| # coding=utf-8 |
| """ |
| @file random_forest.py_in |
| |
| @brief Random Forest: Driver functions |
| |
| @namespace random_forest |
| """ |
| |
| import plpy |
| from math import sqrt |
| |
| from utilities.control import MinWarning |
| from utilities.control import EnableOptimizer |
| from utilities.control import EnableHashagg |
| from utilities.validate_args import get_cols_and_types |
| from utilities.validate_args import is_var_valid |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import output_tbl_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.utilities import _assert |
| from utilities.utilities import unique_string |
| from utilities.utilities import add_postfix |
| from utilities.utilities import split_quoted_delimited_str |
| from utilities.utilities import extract_keyvalue_params |
| |
| from decision_tree import _tree_train_using_bins |
| from decision_tree import _tree_train_grps_using_bins |
| from decision_tree import _get_bins |
| from decision_tree import _get_bins_grps |
| from decision_tree import _get_features_to_use |
| from decision_tree import _get_dep_type |
| from decision_tree import _is_dep_categorical |
| from decision_tree import _get_n_and_deplist |
| from decision_tree import _classify_features |
| from decision_tree import _get_filter_str |
| from decision_tree import _dict_get_quoted |
| from decision_tree import _get_display_header |
| from decision_tree import get_feature_str |
| # ------------------------------------------------------------ |
| |
| |
def forest_train_help_message(schema_madlib, message, **kwargs):
    """ Help message for Random Forest

    Args:
        @param schema_madlib: str, MADlib schema name; substituted for every
                              {schema_madlib} placeholder in the returned text
                              (including the worked example, so it is correct
                              for installations not using the 'madlib' schema)
        @param message: str, selects the help text: empty/None for the summary,
                        'usage'/'help'/'?' for the usage text and
                        'example'/'examples' for a worked example. Matching is
                        case-insensitive and ignores surrounding whitespace.

    Returns:
        str, the requested help text (or a hint pointing at 'usage' when the
        message is not recognized)
    """
    if not message:
        help_string = """
------------------------------------------------------------
                        SUMMARY
------------------------------------------------------------
Functionality: Random Forest

Random forests use a forest-based predictive model to
predict the value of a target variable based on several input variables.

For more details on the function usage:
    SELECT {schema_madlib}.forest_train('usage');
For an example on using this function:
    SELECT {schema_madlib}.forest_train('example');
"""
    elif message.lower().strip() in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------
                        USAGE
------------------------------------------------------------
SELECT {schema_madlib}.forest_train(
    'training_table',       -- Data table name
    'output_table',         -- Table name to store the forest model
    'id_col_name',          -- Row ID, used in forest_predict
    'dependent_variable',   -- The column to fit
    'list_of_features',     -- Comma separated column names to be
                                used as the predictors, can be '*'
                                to include all columns except the
                                dependent_variable
    'features_to_exclude',  -- Comma separated column names to be
                                excluded if list_of_features is '*'
    'grouping_cols',        -- Comma separated column names used to
                                group the data. A random forest model
                                will be created for each group. Default
                                is NULL
    num_trees,              -- Integer, default: 100. Maximum number of trees
                                to grow in the Random Forest model. Actual
                                number of trees grown may be slightly different.
    num_random_features,    -- Integer, default: sqrt(n) if classification tree,
                                otherwise n/3. Number of features to randomly
                                select at each split.
    importance,             -- Boolean, whether to calculate variable importance,
                                default is True
    num_permutations,       -- Number of times to permute each feature value while
                                calculating variable importance, default is 1.
    max_tree_depth,         -- Maximum depth of any node, default is 10
    min_split,              -- Minimum number of observations that must
                                exist in a node for a split to be
                                attempted, default is 20
    min_bucket,             -- Minimum number of observations in any
                                terminal node, default is min_split/3
    num_splits,             -- Number of bins to find possible node
                                split threshold values for continuous
                                variables, default is 100 (Must be greater than 1)
    'surrogate_params',     -- Text, Comma-separated string of key-value pairs
                                controlling the behavior of surrogate splits for
                                each node in a tree. 'max_surrogates=n', where n
                                is a non-negative integer with the default 0.
    verbose,                -- Boolean, whether to print more info,
                                default is False
    sample_ratio            -- Double precision, in the range of (0, 1], default: 1
                                If sample_ratio is less than 1, a bootstrap sample
                                size smaller than the data table is expected to be
                                used for training each tree in the forest.
);

------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the following columns (
quoted items are of text type.):
    gid                     -- Integer. group id that uniquely identifies
                                a set of grouping column values.
    sample_id               -- Integer. id of the bootstrap sample that
                                this tree is a part of.
    tree                    -- bytea8. trained tree model stored in binary
                                format.

The output summary table ('output_table_summary') has the following
columns:
    'method'                -- Method name: 'forest_train'
    is_classification       -- boolean. True if it is a classification model.
    'source_table'          -- Data table name
    'model_table'           -- Tree model table name
    'id_col_name'           -- The ID column name
    'dependent_varname'     -- Response variable column name
    'independent_varnames'  -- Comma-separated feature column names
    'cat_features'          -- Comma-separated column names of categorical variables
    'con_features'          -- Comma-separated column names of continuous variables
    'grouping_cols'         -- Grouping column names
    num_trees               -- Number of trees grown by the model
    num_random_features     -- Number of features randomly selected for each split.
    max_tree_depth          -- Maximum depth of any node.
    min_split               -- Minimum number of observations that must
                                exist in a node for a split to be
                                attempted.
    min_bucket              -- Minimum number of observations in any
                                terminal node.
    num_splits              -- Number of bins to find possible node
                                split threshold values for continuous
                                variables.
    verbose                 -- Boolean, whether to print more info.
    importance              -- Boolean, whether to calculate variable importance.
    num_permutations        -- Number of times to permute each feature value while
                                calculating variable importance
    num_all_groups          -- Int. Number of groups during forest training.
    num_failed_groups       -- Number of groups for which training failed
    total_rows_processed    -- Number of rows used in the model training
    total_rows_skipped      -- Number of rows skipped because NULL values
    dependent_var_levels    -- For classification, the distinct levels of
                                the dependent variable
    dependent_var_type      -- The type of dependent variable

A third table contains the information about the grouping ('output_table_group') and
it has the following columns:
    gid                     -- Integer. group id that uniquely identifies a set
                                of grouping column values.
    <...>                   -- Grouping columns, if provided in input. Same type as
                                in the training data table. This could be multiple
                                columns depending on the grouping_cols input.
    success                 -- Boolean, indicator of the success of the group.
    cat_levels_in_text      -- text[]. Ordered levels of categorical variables.
    cat_n_levels            -- integer[]. Number of levels for each categorical variable.
    oob_error               -- double precision. Out-of-bag error for the random forest model.
    cat_var_importance      -- double precision[]. Variable importance for categorical
                                features. The order corresponds to the order of the
                                variables as found in cat_features in <output_table>_summary.
    con_var_importance      -- double precision[]. Variable importance for continuous
                                features. The order corresponds to the order of the
                                variables as found in con_features in <output_table>_summary.
"""
    elif message.lower().strip() in ['example', 'examples']:
        help_string = """
------------------------------------------------------------
                        EXAMPLE
------------------------------------------------------------
DROP TABLE IF EXISTS dt_golf;
CREATE TABLE dt_golf (
    id integer NOT NULL,
    "OUTLOOK" text,
    temperature double precision,
    humidity double precision,
    windy text,
    class text
);

INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,windy,class) VALUES
(1, 'sunny', 85, 85, 'false', 'Don''t Play'),
(2, 'sunny', 80, 90, 'true', 'Don''t Play'),
(3, 'overcast', 83, 78, 'false', 'Play'),
(4, 'rain', 70, 96, 'false', 'Play'),
(5, 'rain', 68, 80, 'false', 'Play'),
(6, 'rain', 65, 70, 'true', 'Don''t Play'),
(7, 'overcast', 64, 65, 'true', 'Play'),
(8, 'sunny', 72, 95, 'false', 'Don''t Play'),
(9, 'sunny', 69, 70, 'false', 'Play'),
(10, 'rain', 75, 80, 'false', 'Play'),
(11, 'sunny', 75, 70, 'true', 'Play'),
(12, 'overcast', 72, 90, 'true', 'Play'),
(13, 'overcast', 81, 75, 'false', 'Play'),
(14, 'rain', 71, 80, 'true', 'Don''t Play');

DROP TABLE IF EXISTS train_output, train_output_group, train_output_summary;
SELECT {schema_madlib}.forest_train('dt_golf',  -- source table
                           'train_output',      -- output model table
                           'id',                -- id column
                           'class',             -- response
                           '"OUTLOOK", temperature, humidity, windy',  -- features
                           NULL,                -- exclude columns
                           NULL,                -- grouping columns
                           20::integer,         -- number of trees
                           2::integer,          -- number of random features
                           TRUE::boolean,       -- variable importance
                           1::integer,          -- num_permutations
                           8::integer,          -- max depth
                           3::integer,          -- min split
                           1::integer,          -- min bucket
                           10::integer          -- number of splits per continuous variable
                           );
SELECT {schema_madlib}.get_tree('train_output',1,2,FALSE);

"""
    else:
        help_string = "No such option. Use {schema_madlib}.forest_train('usage')"
    return help_string.format(schema_madlib=schema_madlib)
| # ------------------------------------------------------------ |
| |
| |
def forest_train(
        schema_madlib, training_table_name, output_table_name, id_col_name,
        dependent_variable, list_of_features, list_of_features_to_exclude,
        grouping_cols, num_trees, num_random_features,
        importance, num_permutations, max_tree_depth,
        min_split, min_bucket, num_bins,
        surrogate_params, verbose=False, sample_ratio=None, **kwargs):
    """ Random forest main training function

    Grows `num_trees` trees, each on a Poisson bootstrap sample of the training
    data (per group when `grouping_cols` is given). Each trained tree is
    inserted into `output_table_name`. Out-of-bag predictions, out-of-bag error
    and, optionally, variable importance are then computed, and the results are
    handed to _create_group_table / _create_summary_table. The scratch tables
    created here are dropped at the end. The bootstrap views depend on the
    Poisson-count table and go with it via CASCADE.

    Args:
        @param schema_madlib: str, MADlib schema name
        @param training_table_name: str, source table name
        @param output_table_name: str, model table name
        @param id_col_name: str, id column name to uniquely identify each row
        @param dependent_variable: str, dependent variable column name
        @param list_of_features: str, Comma-separated list of feature column names,
                                 can also be '*' implying all columns
                                 except dependent_variable
        @param list_of_features_to_exclude: str, features to exclude if '*' is used
        @param grouping_cols: str, List of grouping columns to group the data
                              (an empty/blank string is treated as None)
        @param num_trees: int, Number of trees in the forest (default 100)
        @param num_random_features: int, Number of random features used in
                                    splitting nodes. When None it defaults to
                                    floor(sqrt(n)) for classification and
                                    max(floor(n / 3), 1) for regression, where
                                    n is the number of selected features.
        @param importance: boolean, Whether or not to calculate variable importance
        @param num_permutations: int, Number of times to permute each feature value
                                 during calculation of variable importance
        @param max_tree_depth: int, Maximum depth of each tree (default 10)
        @param min_split: int, Minimum tuples in a node before splitting it
        @param min_bucket: int, Minimum tuples in each child before splitting a node
        @param num_bins: int, Number of bins for quantizing a continuous
                         variable (default 100)
        @param surrogate_params: str, Comma-separated key=value pairs; only
                                 'max_surrogates' (int, >= 0, default 0) is read
        @param verbose: bool, Whether to emit NOTICE-level progress messages
        @param sample_ratio: float, subsampling ratio for generating src_view
                             (None means 1, i.e. no subsampling)

    Returns:
        None
    """
    msg_level = "'notice'" if verbose else "'warning'"

    # The optimizer (ORCA, on platforms that use it) is disabled because ORCA
    # doesn't provide an easy way to disable hashagg; hashagg is disabled
    # because a large number of groups could result in excessive memory usage.
    with MinWarning(msg_level), EnableOptimizer(False), EnableHashagg(False):
        ##################################################################
        # set default values
        if grouping_cols is not None and grouping_cols.strip() == '':
            grouping_cols = None
        num_trees = 100 if num_trees is None else num_trees
        max_tree_depth = 10 if max_tree_depth is None else max_tree_depth
        # min_split and min_bucket default relative to each other (x3 / //3)
        min_split = 20 if min_split is None and min_bucket is None else min_split
        min_bucket = min_split // 3 if not min_bucket else min_bucket
        min_split = min_bucket * 3 if not min_split else min_split
        num_bins = 100 if num_bins is None else num_bins
        sample_ratio = 1 if sample_ratio is None else sample_ratio

        surrogate_param_dict = extract_keyvalue_params(
            surrogate_params,
            dict(max_surrogates=int),  # type of variable
            dict(max_surrogates=0))    # default value of variable
        max_n_surr = surrogate_param_dict['max_surrogates']
        _assert(max_n_surr >= 0,
                "Maximum number of surrogates ({0}) should be non-negative".
                format(max_n_surr))

        ##################################################################
        # validate arguments
        _forest_validate_args(training_table_name, output_table_name, id_col_name,
                              list_of_features, dependent_variable,
                              list_of_features_to_exclude, grouping_cols,
                              num_trees, num_random_features,
                              num_permutations, max_tree_depth,
                              min_split, min_bucket, num_bins, sample_ratio)

        ##################################################################
        # preprocess arguments
        # expand "*" syntax and exclude some features
        features = _get_features_to_use(schema_madlib,
                                        training_table_name,
                                        list_of_features,
                                        list_of_features_to_exclude,
                                        id_col_name,
                                        '1', dependent_variable,
                                        grouping_cols)

        _assert(bool(features),
                "Random forest error: No feature is selected for the model.")

        is_classification, is_bool = _is_dep_categorical(
            training_table_name, dependent_variable)
        split_criterion = 'gini' if is_classification else 'mse'

        if num_random_features is None:
            n_all_features = len(features)
            # The per-split feature count must be a positive integer. sqrt()
            # returns a float, and a plain `n / 3` is 0 for n < 3 under
            # Python 2 integer division. features is non-empty (asserted
            # above), so both branches yield >= 1.
            if is_classification:
                num_random_features = int(sqrt(n_all_features))
            else:
                num_random_features = max(n_all_features // 3, 1)
        _assert(num_random_features <= len(features),
                "Random forest error: Number of features to be selected "
                "is more than the actual number of features.")

        all_cols_types = dict(get_cols_and_types(training_table_name))
        cat_features, ordered_cat_features, con_features, boolean_cats = \
            _classify_features(all_cols_types, features)

        filter_null = _get_filter_str(dependent_variable, grouping_cols)
        # the total number of records
        n_all_rows = plpy.execute("SELECT count(*) FROM {0}".
                                  format(training_table_name))[0]['count']

        if is_classification:
            # For classifications, we also need to map dependent_variable to integers
            n_rows, dep_list = _get_n_and_deplist(training_table_name,
                                                  dependent_variable,
                                                  filter_null)
            _assert(n_rows > 0,
                    "Random forest error: There should be at least one "
                    "data point for each class where all features are non NULL")
            dep_list.sort()
            dep_col_str = ("CASE WHEN " + dependent_variable +
                           " THEN 'True' ELSE 'False' END") if is_bool else dependent_variable
            # map each (sorted) class label to its 0-based index
            dep = ("(CASE " +
                   "\n ".join([
                       "WHEN ({dep_col})::text = $${c}$$ THEN {i}".format(
                           dep_col=dep_col_str, c=c, i=i)
                       for i, c in enumerate(dep_list)]) +
                   "\nEND)")
            dep_n_levels = len(dep_list)
        else:
            n_rows = plpy.execute(
                "SELECT count(*) FROM {source_table} where {filter_null}".
                format(source_table=training_table_name,
                       filter_null=filter_null))[0]['count']
            dep = dependent_variable
            dep_n_levels = 1
            dep_list = None

        # a table that maps gid/grp_key to actual columns
        grp_key_to_grp_cols = unique_string()
        # create the above table and perform binning
        if grouping_cols is None:
            sql_grp_key_to_grp_cols = """
                CREATE TABLE {grp_key_to_grp_cols} AS
                SELECT ''::text AS grp_key, 1 AS gid
            """.format(**locals())
            plpy.notice("sql_grp_key_to_grp_cols:\n" + sql_grp_key_to_grp_cols)
            plpy.execute(sql_grp_key_to_grp_cols)

            # find the bins, one dict containing two arrays: categorical
            # bins, and continuous bins
            num_groups = 1
            bins = _get_bins(schema_madlib, training_table_name,
                             cat_features, ordered_cat_features,
                             con_features, num_bins, dep,
                             boolean_cats, n_rows, is_classification,
                             dep_n_levels, filter_null)
            # some features may be dropped because they have only one value
            cat_features = bins['cat_features']
            bins['grp_key_cat'] = ['']
        else:
            grouping_cols_list = [col.strip() for col in grouping_cols.split(',')]
            grouping_cols_and_types = [(col, _dict_get_quoted(all_cols_types, col))
                                       for col in grouping_cols_list]
            # text key of a group: comma-joined text of its grouping values
            # (booleans spelled 'True'/'False')
            grouping_array_str = (
                "array_to_string(array[" +
                ','.join("(case when " + col + " then 'True' else 'False' end)::text"
                         if col_type == 'boolean' else '(' + col + ')::text'
                         for col, col_type in grouping_cols_and_types) +
                "], ',')")
            # grouping_cols is not None in this branch
            grouping_cols_str = grouping_cols + ","
            sql_grp_key_to_grp_cols = """
                CREATE TABLE {grp_key_to_grp_cols} AS
                SELECT
                    {grouping_cols},
                    {grouping_array_str} AS grp_key,
                    (row_number() over ())::integer AS gid
                FROM {training_table_name}
                GROUP BY {grouping_cols}
            """.format(**locals())
            plpy.notice("sql_grp_key_to_grp_cols:\n" + sql_grp_key_to_grp_cols)
            plpy.execute(sql_grp_key_to_grp_cols)

            # find bins
            num_groups = plpy.execute("""
                SELECT count(*) FROM {grp_key_to_grp_cols}
            """.format(**locals()))[0]['count']
            plpy.notice("Analyzing data to compute split boundaries for variables")
            bins = _get_bins_grps(schema_madlib, training_table_name,
                                  cat_features, ordered_cat_features,
                                  con_features, num_bins, dep,
                                  boolean_cats, grouping_cols,
                                  grouping_array_str, n_rows,
                                  is_classification, dep_n_levels, filter_null)
            cat_features = bins['cat_features']

        # a table for converting cat_features to integers
        cat_features_info_table = unique_string()
        sql_cat_features_info = """
            CREATE TEMP TABLE {cat_features_info_table} AS
            SELECT
                gid,
                cat_n_levels,
                cat_levels_in_text
            FROM
            (
                SELECT *
                FROM {schema_madlib}._gen_cat_levels_set($1, $2, $3, $4)
            ) subq
            JOIN
                {grp_key_to_grp_cols}
            USING (grp_key)
        """.format(**locals())
        plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info)
        plan_cat_features_info = plpy.prepare(
            sql_cat_features_info, ['text[]', 'integer[]', 'integer', 'text[]'])
        plpy.execute(plan_cat_features_info, [
            bins['grp_key_cat'],
            bins['cat_n'],
            len(cat_features),
            bins['cat_origin']])

        con_splits_table = unique_string()
        _create_con_splits_table(schema_madlib, con_splits_table, grouping_cols,
                                 grp_key_to_grp_cols, bins)

        ##################################################################
        # create views and tables for training (growing) of trees
        # store the prediction for all oob samples
        # for classification, the prediction is of integer type here
        oob_prediction_table = unique_string()
        sql_create_oob_prediction_table = """
            CREATE TEMP TABLE {oob_prediction_table} AS
            SELECT
                {id_col_name},
                1 AS sample_id,
                1 AS gid,
                {dep} AS dep,
                {dep} AS oob_prediction,
                ARRAY[1.0]::float8[] AS cat_imp_score,
                ARRAY[1.0]::float8[] AS con_imp_score
            FROM {training_table_name}
            LIMIT 0
        """.format(**locals())
        plpy.notice("sql_create_oob_prediction_table:\n" + sql_create_oob_prediction_table)
        plpy.execute(sql_create_oob_prediction_table)

        # to store poisson count defining bootstrap sample
        training_pois_cnt_table = unique_string()
        subsample_random_column = unique_string()
        # Layout: source columns, the subsampling random value, poisson_count.
        # The refresh INSERT below fills it positionally in that order; the
        # placeholder column is named explicitly instead of relying on the
        # server-generated default name.
        sql_create_training_pois_cnt = """
            CREATE TEMP TABLE {training_pois_cnt_table} AS
            SELECT
                *,
                1.::double precision AS {subsample_random_column},
                {schema_madlib}.poisson_random(1) AS poisson_count
            FROM {training_table_name}
            LIMIT 0
        """.format(**locals())
        plpy.notice("sql_create_training_pois_cnt:\n" + sql_create_training_pois_cnt)
        plpy.execute(sql_create_training_pois_cnt)

        # views dependent on current bootstrap sample
        src_view = unique_string()
        sql_create_src_view = """
            CREATE VIEW {src_view} AS
            SELECT *
            FROM {training_pois_cnt_table}
            WHERE poisson_count != 0
        """.format(**locals())
        plpy.notice("sql_create_src_view:\n" + sql_create_src_view)
        plpy.execute(sql_create_src_view)

        oob_view = unique_string()
        sql_create_oob_view = """
            CREATE VIEW {oob_view} AS
            SELECT *
            FROM {training_pois_cnt_table}
            WHERE poisson_count = 0
        """.format(**locals())
        plpy.notice("sql_create_oob_view:\n" + sql_create_oob_view)
        plpy.execute(sql_create_oob_view)

        _create_empty_result_table(schema_madlib, output_table_name)

        ##################################################################
        # training random forest
        tree_terminated = None
        # The subsampling expression is the same for every tree, so choose it
        # once: with sample_ratio (effectively) 1 every row passes the
        # "< sample_ratio" filter, otherwise random() keeps roughly a
        # sample_ratio fraction of rows.
        if 1 - sample_ratio < 1e-6:
            random_sample_expr = "0.::double precision"
        else:
            random_sample_expr = "random()"

        for sample_id in range(1, num_trees + 1):
            sql_refresh_training_pois_cnt = """
                TRUNCATE TABLE {training_pois_cnt_table} CASCADE;
                INSERT INTO {training_pois_cnt_table}
                SELECT
                    *,
                    {schema_madlib}.poisson_random(1) AS poisson_count
                FROM
                (
                    SELECT
                        *,
                        {random_sample_expr} AS {subsample_random_column}
                    FROM {training_table_name}
                ) subq
                WHERE {subsample_random_column} < {sample_ratio}
            """.format(**locals())
            plpy.notice("sql_refresh_training_pois_cnt:\n" + sql_refresh_training_pois_cnt)
            plpy.execute(sql_refresh_training_pois_cnt)

            if verbose:
                tup_cnt_in_view = plpy.execute("""
                    SELECT
                        count(*) AS c,
                        sum(poisson_count) AS s
                    FROM {src_view}
                """.format(**locals()))[0]
                src_cnt = tup_cnt_in_view['c']
                dup_cnt = tup_cnt_in_view['s']
                oob_cnt = plpy.execute("""
                    SELECT count(*) AS c FROM {oob_view}
                """.format(**locals()))[0]['c']
                plpy.notice("""
                    src_cnt: {src_cnt},
                    oob_cnt: {oob_cnt},
                    dup_cnt: {dup_cnt}.
                """.format(**locals()))

            if grouping_cols is None:  # non-grouping case
                tree_state = _tree_train_using_bins(
                    schema_madlib, bins, src_view, cat_features, con_features,
                    boolean_cats, num_bins, 'poisson_count', dep, min_split,
                    min_bucket, max_tree_depth, filter_null, dep_n_levels,
                    is_classification, split_criterion, True,
                    num_random_features, max_n_surr)
                tree_states = [dict(tree_state=tree_state['tree_state'],
                                    grp_key='')]

                tree_terminated = {'': tree_state['finished']}

            else:
                tree_states = _tree_train_grps_using_bins(
                    schema_madlib, bins, src_view, cat_features, con_features,
                    boolean_cats, num_bins, 'poisson_count', grouping_cols,
                    grouping_array_str, dep, min_split, min_bucket,
                    max_tree_depth, filter_null, dep_n_levels,
                    is_classification, split_criterion, True,
                    num_random_features, tree_terminated=tree_terminated,
                    max_n_surr=max_n_surr)

                # If a tree for a group is terminated (not finished properly),
                # then we do not need to compute other trees, and can just
                # stop calculating that group further.
                if tree_terminated is None:
                    tree_terminated = dict((item['grp_key'], item['finished'])
                                           for item in tree_states)
                else:
                    for item in tree_states:
                        if item['grp_key'] not in tree_terminated:
                            tree_terminated[item['grp_key']] = item['finished']
                        elif item['finished'] == 2:
                            tree_terminated[item['grp_key']] = 2

            _insert_into_result_table(
                schema_madlib, tree_states, output_table_name,
                grp_key_to_grp_cols, sample_id)

            _calculate_oob_prediction(
                schema_madlib, output_table_name, cat_features_info_table,
                con_splits_table, oob_prediction_table, oob_view,
                sample_id, id_col_name, cat_features, con_features,
                boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
                num_permutations, is_classification, importance, num_bins)

        ###################################################################
        # evaluating and summarizing random forest

        oob_error_table = unique_string()
        _calculate_oob_error(schema_madlib, oob_prediction_table,
                             oob_error_table, id_col_name,
                             is_classification)

        importance_table = unique_string()
        sql_create_empty_imp_tbl = """
            CREATE TEMP TABLE {importance_table}
            (
                gid integer,
                cat_var_importance float8[],
                con_var_importance float8[]
            );
        """.format(**locals())
        plpy.notice("sql_create_empty_imp_tbl:\n" + sql_create_empty_imp_tbl)
        plpy.execute(sql_create_empty_imp_tbl)

        # we populate the importance_table only if variable importance is to be
        # calculated, otherwise we use an empty table which will be used later
        # for an outer join.
        if importance:
            _calculate_variable_importance(schema_madlib,
                                           oob_prediction_table, is_classification,
                                           importance_table, len(cat_features),
                                           len(con_features))

        _create_group_table(schema_madlib, output_table_name,
                            oob_error_table, importance_table,
                            cat_features_info_table, grp_key_to_grp_cols,
                            grouping_cols, tree_terminated)

        # a group counts as failed unless its terminated flag is exactly 1
        num_failed_groups = sum(1 for v in tree_terminated.values() if v != 1)
        # NOTE: every local variable of this function is forwarded here, so
        # local names in this function are part of _create_summary_table's
        # input; do not rename them without checking that helper.
        _create_summary_table(**locals())

        sql_cleanup = """
            DROP TABLE {training_pois_cnt_table} CASCADE;
            DROP TABLE {oob_prediction_table} CASCADE;
            DROP TABLE {importance_table} CASCADE;
            DROP TABLE {oob_error_table} CASCADE;
            DROP TABLE {cat_features_info_table} CASCADE;
            DROP TABLE {con_splits_table} CASCADE;
            DROP TABLE {grp_key_to_grp_cols} CASCADE;
        """.format(**locals())
        plpy.notice("sql_cleanup:\n" + sql_cleanup)
        plpy.execute(sql_cleanup)

    return None
| # ------------------------------------------------------------ |
| |
| |
def forest_predict(schema_madlib, model, source, output, pred_type='response',
                   **kwargs):
    """ Random forest prediction

    Args:
        @param schema_madlib: str, Name of MADlib schema
        @param model: str, Name of table containing the forest model
        @param source: str, Name of table containing prediction data
        @param output: str, Name of table to output the results
        @param pred_type: str, The type of output required:
                    'response' gives the actual response values,
                    'prob' gives the probability of the classes in a
                  classification model.
                    For regression model, only type='response' is defined.
                    None or '' is treated as 'response'.

    Returns:
        None

    Side effect:
        Creates an output table containing the prediction for given source
        table. For pred_type='response' there is one row per id. The value is
        the average of the tree outputs (regression) or the most frequent
        class (classification, via mode()). For pred_type='prob' each class
        gets an "estimated_prob_<class>" column built from
        discrete_distribution_agg over the tree votes.

    Throws:
        Nothing directly; argument problems are reported by _validate_predict
        and _assert.
    """
    pred_type = 'response' if pred_type is None or pred_type == '' else pred_type
    _validate_predict(model, source, output, pred_type)

    model_summary = add_postfix(model, "_summary")
    model_group = add_postfix(model, "_group")
    # obtain the cat_features and con_features from model table
    summary_elements = plpy.execute("SELECT * FROM {0}".format(model_summary))[0]
    cat_features = split_quoted_delimited_str(summary_elements["cat_features"])
    con_features = split_quoted_delimited_str(summary_elements["con_features"])
    id_col_name = summary_elements["id_col_name"]
    grouping_cols = summary_elements["grouping_cols"]
    dep_varname = summary_elements["dependent_varname"]
    dep_levels = summary_elements["dependent_var_levels"]
    is_classification = summary_elements["is_classification"]
    dep_type = summary_elements['dependent_var_type']

    # pred_type='prob' is allowed only for classification
    _assert(is_classification or pred_type == 'response',
            "Random forest error: pred_type cannot be 'prob' for regression model.")

    # find which columns are of type boolean
    boolean_cats = set(key for key, value in get_cols_and_types(source)
                       if value == 'boolean')

    cat_features_str, con_features_str = get_feature_str(
        schema_madlib, boolean_cats, cat_features, con_features,
        "cat_levels_in_text", "cat_n_levels")

    pred_name = ('"prob_{0}"' if pred_type == "prob" else
                 '"estimated_{0}"').format(dep_varname.replace('"', '').strip())

    # Without grouping the single model_group row is cross-joined; with
    # grouping each source row is matched to its group's trees.
    join_str = "," if grouping_cols is None else "JOIN"
    using_str = "" if grouping_cols is None else "USING (" + grouping_cols + ")"

    if not is_classification:
        majority_pred_expression = "avg(aggregated_prediction)"
    else:
        # tree outputs are 0-based class indices; SQL arrays are 1-based
        majority_pred_expression = """($sql${{ {dep_levels} }}$sql$::varchar[])[
        {schema_madlib}.mode(aggregated_prediction + 1)]::TEXT
        """.format(**locals())

    if dep_type.lower() == "boolean":
        # some platforms don't have text to boolean cast. We manually check the string.
        majority_pred_cast_str = ("(case {majority_pred_expression} when 'True' then "
                                  "True else False end)::BOOLEAN as {pred_name}")
    else:
        majority_pred_cast_str = "{majority_pred_expression}::{dep_type} as {pred_name}"

    majority_pred_cast_str = majority_pred_cast_str.format(**locals())

    if pred_type == "response" or not is_classification:
        sql_prediction = """
            CREATE TABLE {output} AS
            SELECT
                {id_col_name},
                {majority_pred_cast_str}
            FROM
            (
                SELECT
                    {id_col_name},
                    {schema_madlib}._predict_dt_response(
                        tree,
                        {cat_features_str}::integer[],
                        {con_features_str}::double precision[]) AS aggregated_prediction
                FROM
                    {source}
                    {join_str}
                    {model_group}
                    {using_str}
                    JOIN
                    {model}
                    USING (gid)
            ) prediction_agg
            GROUP BY {id_col_name}
        """.format(**locals())
    else:
        # parse the stored class levels once and reuse them
        dep_level_list = split_quoted_delimited_str(dep_levels)
        len_dep_levels = len(dep_level_list)
        normalized_majority_pred = unique_string()
        score_format = ', \n'.join(
            '{temp}[{j}] as "estimated_prob_{c}"'.
            format(j=i + 1, c=c.strip(' "'), temp=normalized_majority_pred)
            for i, c in enumerate(dep_level_list))

        sql_prediction = """
            CREATE TABLE {output} AS
            SELECT
                {id_col_name},
                {score_format}
            FROM
            (
                SELECT
                    {id_col_name},
                    {schema_madlib}.discrete_distribution_agg(
                        prediction::integer,
                        1.,
                        {len_dep_levels}
                    )::double precision[]
                    AS {normalized_majority_pred}
                FROM
                (
                    SELECT
                        {id_col_name},
                        gid,
                        {schema_madlib}._predict_dt_response(
                            tree,
                            {cat_features_str}::integer[],
                            {con_features_str}::double precision[]) as prediction
                    FROM
                        {source}
                        {join_str}
                        {model_group}
                        {using_str}
                        JOIN
                        {model}
                        USING (gid)
                ) class_prediction_subq
                GROUP BY gid, {id_col_name}
            ) subq
        """.format(**locals())

    # The optimizer (ORCA, on platforms that use it) is disabled because ORCA
    # doesn't provide an easy way to disable hashagg; hashagg is disabled
    # because a large number of groups could result in excessive memory usage.
    with MinWarning('warning'), EnableOptimizer(False), EnableHashagg(False):
        plpy.notice("sql_prediction:\n" + sql_prediction)
        plpy.execute(sql_prediction)

    return None
| # ------------------------------------------------------------ |
| |
| |
def get_tree_surr(schema_madlib, model_table, gid, sample_id, **kwargs):
    """Show the surrogate splits of one tree of a random forest model.

    Delegates to get_tree() with surrogate display on and dot output off
    (get_tree cannot render surrogates in dot format). Any extra keyword
    arguments are accepted and ignored.

    Args:
        @param schema_madlib: str, MADlib schema name
        @param model_table: str, forest model table name
        @param gid: int, group id of the tree
        @param sample_id: int, bootstrap sample id of the tree

    Returns:
        Whatever get_tree returns for this tree.
    """
    surrogate_display = dict(disp_surr=True, dot_format=False)
    return get_tree(schema_madlib=schema_madlib,
                    model_table=model_table,
                    gid=gid,
                    sample_id=sample_id,
                    **surrogate_display)
| |
| |
def get_tree(schema_madlib, model_table, gid, sample_id,
             dot_format=True, verbose=False, disp_surr=False, **kwargs):
    """Random forest tree display function.

    Render the tree identified by (gid, sample_id) from a trained forest in
    one of three forms: a DOT digraph (dot_format=True), indented text, or
    (disp_surr=True) the surrogate splits of its internal nodes.

    Args:
        schema_madlib: schema where MADlib is installed.
        model_table: forest model table. The '<model_table>_summary' and
            '<model_table>_group' tables are derived from this name.
        gid: group id of the tree to display.
        sample_id: sample id of the tree within its group.
        dot_format: if True, emit DOT output. This forces disp_surr=False.
        verbose: forwarded to the DOT display function.
        disp_surr: if True (and not dot_format), show surrogate information.

    Returns:
        The display string. If no tree matches (gid, sample_id), a warning is
        raised and None is returned.
    """
    _validate_get_tree(model_table, gid, sample_id)
    if dot_format:
        disp_surr = False  # surrogates cannot be displayed in dot format
    bytea8 = schema_madlib + '.bytea8'

    model_table_summary = add_postfix(model_table, "_summary")
    model_table_group = add_postfix(model_table, "_group")
    summary = plpy.execute("SELECT * FROM {model_table_summary}".
                           format(model_table_summary=model_table_summary))[0]
    dep_levels = summary["dependent_var_levels"]
    dep_levels = [''] if not dep_levels else split_quoted_delimited_str(dep_levels)
    table_name = summary["source_table"]
    is_regression = not summary["is_classification"]
    cat_features_str = split_quoted_delimited_str(summary['cat_features'])
    con_features_str = split_quoted_delimited_str(summary['con_features'])

    with MinWarning('warning'):
        sql_tree_result = """
            SELECT
                tree,
                cat_levels_in_text,
                cat_n_levels
            FROM
                {model_table}
            JOIN
                {model_table_group}
            USING (gid)
            WHERE sample_id = {sample_id}
              AND gid = {gid}
            """.format(**locals())
        plpy.notice("sql_tree_result:\n" + sql_tree_result)
        tree_result = plpy.execute(sql_tree_result)

    if not tree_result:
        # plpy.warning only emits a message and does not abort. Without this
        # return, tree_result[0] below would raise IndexError.
        plpy.warning("no tree found by the given gid and sample_id, exiting...")
        return None
    tree = tree_result[0]

    if tree['cat_levels_in_text']:
        cat_levels_in_text = tree['cat_levels_in_text']
        cat_n_levels = tree['cat_n_levels']
    else:
        cat_levels_in_text = []
        cat_n_levels = []

    return_str_list = []
    if not disp_surr:
        return_str_list.append(_get_display_header(table_name, dep_levels,
                                                   is_regression, dot_format))
    else:
        return_str_list.append("""
            -------------------------------------
                Surrogates for internal nodes
            -------------------------------------
            """)

    with MinWarning('warning'):
        if disp_surr:
            # Output only surrogate information for the internal nodes of tree
            sql = """SELECT {0}._display_decision_tree_surrogate(
                            $1, $2, $3, $4, $5) as display_tree
                  """.format(schema_madlib)
            # execute sql to get display string
            sql_plan = plpy.prepare(sql, [bytea8,
                                          'text[]', 'text[]', 'text[]',
                                          'int[]'])
            tree_display = plpy.execute(
                sql_plan, [tree['tree'], cat_features_str, con_features_str,
                           cat_levels_in_text, cat_n_levels])[0]
        else:
            # Output the splits in each node of the tree
            if dot_format:
                sql_display = """
                    SELECT {0}._display_decision_tree(
                        $1, $2, $3, $4, $5, $6, '{1}', {2}
                    ) as display_tree
                    """.format(schema_madlib, "", verbose)
            else:
                sql_display = """
                    SELECT {0}._display_text_decision_tree(
                        $1, $2, $3, $4, $5, $6
                    ) as display_tree
                    """.format(schema_madlib)

            plpy.notice("sql_display:\n" + sql_display)
            plan_display = plpy.prepare(sql_display, [bytea8,
                                                      'text[]', 'text[]', 'text[]',
                                                      'int[]', 'text[]'])
            tree_display = plpy.execute(
                plan_display, [tree['tree'], cat_features_str, con_features_str,
                               cat_levels_in_text, cat_n_levels,
                               dep_levels])[0]

    return_str_list.append(tree_display["display_tree"])
    if dot_format:
        return_str_list.append("} //---end of digraph--------- ")
    return "\n".join(return_str_list)
| # ------------------------------------------------------------ |
| |
| |
def _calculate_oob_prediction(
        schema_madlib, model_table, cat_features_info_table, con_splits_table,
        oob_prediction_table, oob_view, sample_id, id_col_name, cat_features,
        con_features, boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
        num_permutations, is_classification, importance, num_bins):
    """Calculate predictions for the out-of-bag sample of one tree.

    For every row of `oob_view`, inserts into `oob_prediction_table`:
    (id, sample_id, gid, dep, oob_prediction, cat_imp_score, con_imp_score).
    The prediction comes from the tree of `model_table` with the given
    `sample_id`. When `importance` is True, a helper view is built first with
    per-group distributions of the categorical features and of the continuous
    feature bin indices. Those distributions are passed to the
    _rf_cat_imp_score / _rf_con_imp_score functions. When `importance` is
    False the distributions are NULL, and the SQL comments note that the
    score functions then return NULL.

    The helper view is dropped after the INSERT, so calling this once per
    tree does not leave one view per tree in the database.
    """
    cat_features_str, con_features_str = get_feature_str(
        schema_madlib, boolean_cats, cat_features, con_features,
        "cat_levels_in_text", "cat_n_levels")

    # Without grouping, the grouping table is cross-joined (",") and needs no
    # USING clause.
    join_str = "," if grouping_cols is None else "JOIN"
    using_str = "" if grouping_cols is None else "USING (" + grouping_cols + ")"

    oob_var_dist_view = unique_string()
    if importance:
        sql_create_oob_var_dist_view = """
            CREATE VIEW {oob_var_dist_view} AS
            SELECT
                gid,
                {schema_madlib}.vectorized_distribution_agg(
                    {schema_madlib}.array_scalar_add(
                        {cat_features_str}::integer[],
                        1   -- -1 shift to 0 for nulls
                    ),
                    {schema_madlib}.array_scalar_add(
                        cat_n_levels,
                        1   -- -1 shift to 0 for nulls
                    )
                ) AS cat_feature_distributions,
                {schema_madlib}.vectorized_distribution_agg(
                    {schema_madlib}.array_scalar_add(
                        {schema_madlib}._get_bin_indices_by_values(
                            {con_features_str}::double precision[],
                            con_splits
                        ),  -- bin_indices, -1 for NaN
                        1   -- -1 shift to 0 for nulls
                    ),
                    {schema_madlib}.array_fill(
                        {schema_madlib}.array_of_float({n_con})::integer[],
                        ({num_bins}+1)::integer
                    )
                    -- level of any continuous feature == num_bins
                ) AS con_index_distributions
            FROM
                {oob_view}
            {join_str}
                {grp_key_to_grp_cols}
            {using_str}
            JOIN
                {cat_features_info_table}
            USING (gid)
            JOIN
                {con_splits_table}
            USING (gid)
            GROUP BY gid
            """.format(n_con=len(con_features), **locals())
    else:
        sql_create_oob_var_dist_view = """
            CREATE VIEW {oob_var_dist_view} AS
            SELECT
                gid,
                NULL::float8[] AS cat_feature_distributions,
                NULL::float8[] AS con_index_distributions
            FROM {cat_features_info_table}
            """.format(**locals())

    plpy.notice("sql_create_oob_var_dist_view : " + str(sql_create_oob_var_dist_view))
    plpy.execute(sql_create_oob_var_dist_view)

    sql_oob_predict = """
        INSERT INTO {oob_prediction_table}
        SELECT
            {id_col_name},
            sample_id,
            gid,
            {dep} AS dep,
            {schema_madlib}._predict_dt_response(
                tree,
                {cat_features_str}::integer[],
                {con_features_str}::double precision[]
            ) AS oob_prediction,
            {schema_madlib}._rf_cat_imp_score(
                tree,
                {cat_features_str}::integer[],
                {con_features_str}::double precision[],
                cat_info.cat_n_levels,
                {num_permutations},
                {dep},
                {is_classification},
                cat_feature_distributions  -- if distribution is NULL, returns NULL
            ) AS cat_imp_score,
            {schema_madlib}._rf_con_imp_score(
                tree,
                {cat_features_str}::integer[],
                {con_features_str}::double precision[],
                con_info.con_splits,
                {num_permutations},
                {dep},
                {is_classification},
                con_index_distributions  -- if distribution is NULL, returns NULL
            ) AS con_imp_score
        FROM
            {oob_view}
        {join_str}
            {grp_key_to_grp_cols}
        {using_str}
        JOIN
            (
                SELECT *
                FROM {model_table}
                WHERE sample_id = {sample_id}
            ) m
        USING (gid)
        JOIN
            {cat_features_info_table} cat_info
        USING (gid)
        JOIN
            {con_splits_table} con_info
        USING (gid)
        LEFT OUTER JOIN  -- empty if variable importance is disabled
            {oob_var_dist_view}
        USING (gid)
        """.format(**locals())
    plpy.notice("sql_oob_predict : " + str(sql_oob_predict))
    plpy.execute(sql_oob_predict)

    # The distribution view exists only for the INSERT above, and its name is
    # local to this call. Drop it so one view per tree is not left behind.
    plpy.execute("DROP VIEW IF EXISTS {0}".format(oob_var_dist_view))
| # ------------------------------------------------------------------------- |
| |
| |
def _create_con_splits_table(schema_madlib, con_splits_table, grouping_cols,
                             grp_key_to_grp_cols, bins):
    """Materialize continuous-feature split boundaries as (gid, con_splits).

    Without grouping, a TEMP table with the single row (1, bins['con']) is
    created. With grouping, bins['grp_key_con'] and bins['con'] are unnested
    in parallel and joined to `grp_key_to_grp_cols` on grp_key to obtain
    each group's gid. This branch creates a regular (non-TEMP) table.
    """
    if grouping_cols is not None:
        create_sql = """
            CREATE TABLE {con_splits_table} AS
            SELECT
                gid,
                con_splits
            FROM
                {grp_key_to_grp_cols}
            JOIN
                (
                    SELECT
                        unnest($1) as grp_key,
                        unnest($2) as con_splits
                ) subq
            USING (grp_key)
            """.format(con_splits_table=con_splits_table,
                       grp_key_to_grp_cols=grp_key_to_grp_cols)
        param_types = ['text[]', schema_madlib + '.bytea8[]']
        param_keys = ('grp_key_con', 'con')
    else:
        create_sql = """
            CREATE TEMP TABLE {con_splits_table} AS
            SELECT
                1 AS gid,
                $1 AS con_splits
            """.format(con_splits_table=con_splits_table)
        param_types = [schema_madlib + '.bytea8']
        param_keys = ('con',)

    plpy.notice("sql_create_con_splits_table:\n" + create_sql)
    create_plan = plpy.prepare(create_sql, param_types)
    # bins entries are read only after the plan is prepared, as before.
    plpy.execute(create_plan, [bins[key] for key in param_keys])
| # ------------------------------------------------------------------------------ |
| |
| |
def _calculate_variable_importance(schema_madlib, oob_prediction_table,
                                   is_classification, importance_table, n_cat, n_con):
    """Aggregate out-of-bag permutation scores into `importance_table`.

    For each (sample_id, gid), the per-row OOB score (1/0 correctness for
    classification, negative squared error for regression) is summed into
    `score`, and the permuted scores are summed element-wise. The importance
    row appended per gid is the array_avg over samples of
    (score - permuted_score) / size, computed separately for categorical and
    continuous features.

    n_cat and n_con are accepted for interface compatibility and are not
    used here. The intermediate view is dropped after the INSERT.
    """
    # Both expressions are "higher is better", so a drop after permuting a
    # feature shows up as a positive importance.
    if not is_classification:
        score_expression = "-((oob_prediction - dep)^2)"
    else:
        score_expression = """
            CASE WHEN dep = oob_prediction::integer
                THEN 1.
                ELSE 0.
            END"""

    sample_score_view = unique_string()
    sql_create_sample_score_view = """
        CREATE VIEW {sample_score_view} AS
        SELECT
            sample_id,
            gid,
            count(*) as size,
            sum({score_expression}) as score,
            {schema_madlib}.sum(cat_imp_score::FLOAT8[]) AS cat_imp_score,
            {schema_madlib}.sum(con_imp_score::FLOAT8[]) AS con_imp_score
        FROM
            {oob_prediction_table}
        GROUP BY sample_id, gid
        """.format(**locals())
    plpy.notice("sql_create_sample_score_view:\n" + sql_create_sample_score_view)
    plpy.execute(sql_create_sample_score_view)

    sql_create_importance_table = """
        INSERT INTO {importance_table}
        SELECT
            gid,
            {schema_madlib}.array_avg(
                {schema_madlib}.array_scalar_mult(
                    {schema_madlib}.array_scalar_add(
                        cat_imp_score,
                        -score::float8
                    ),
                    (-1. / size)::float8
                ),
                FALSE  -- not use absolute values
            ),
            {schema_madlib}.array_avg(
                {schema_madlib}.array_scalar_mult(
                    {schema_madlib}.array_scalar_add(
                        con_imp_score,
                        -score::float8
                    ),
                    (-1. / size)::float8
                ),
                FALSE  -- not use absolute values
            )
        FROM
            {sample_score_view}
        GROUP BY gid
        """.format(**locals())
    plpy.notice("sql_create_importance_table:\n" + sql_create_importance_table)
    plpy.execute(sql_create_importance_table)

    # The per-sample score view is only needed by the INSERT above, and its
    # name is local to this call. Drop it rather than leave it in the schema.
    plpy.execute("DROP VIEW IF EXISTS {0}".format(sample_score_view))
| # ------------------------------------------------------------------------- |
| |
| |
def _calculate_oob_error(schema_madlib, oob_prediction_table, oob_error_table,
                         id_col_name, is_classification):
    """Calculate out-of-bag error for oob samples.

    Per (gid, id, dep), the OOB predictions of all trees are combined into a
    forest prediction: `avg` for regression, `<schema_madlib>.mode` for
    classification. The per-gid mean of the error (squared error, or a 0/1
    misclassification indicator) is written to the new table
    `oob_error_table` with columns (gid, oob_error).
    """
    if is_classification:
        combine_fn = "{0}.mode".format(schema_madlib)
        error_expr = """
            CASE WHEN dep = forest_prediction::integer
                THEN 0.
                ELSE 1.
            END"""
    else:
        combine_fn = 'avg'
        error_expr = "(dep - forest_prediction)^2"

    sql_compute_oob_error = """
        CREATE TABLE {oob_error_table} AS
        SELECT
            gid,
            avg({residual_expression}) AS oob_error
        FROM
            (
                SELECT
                    gid,
                    dep,
                    {forest_prediction_agg}(oob_prediction) AS forest_prediction
                FROM
                    {oob_prediction_table}
                GROUP BY gid, {id_col_name}, dep
            ) prediction_subq
        GROUP BY gid
        """.format(oob_error_table=oob_error_table,
                   residual_expression=error_expr,
                   forest_prediction_agg=combine_fn,
                   oob_prediction_table=oob_prediction_table,
                   id_col_name=id_col_name)
    plpy.notice("sql_compute_oob_error : " + str(sql_compute_oob_error))
    plpy.execute(sql_compute_oob_error)
| # ------------------------------------------------------------------------- |
| |
| |
def _create_summary_table(**kwargs):
    """Create the '<output_table_name>_summary' table for forest_train.

    Reads from kwargs the training inputs and parameters referenced below,
    including cat_features/con_features (lists of names), dep_list,
    all_cols_types, training_table_name, dependent_variable, grouping_cols,
    n_all_rows and n_rows. Writes them as a single summary row.

    Free-text values (table and column names, the dependent variable
    expression, feature and type lists) are embedded as dollar-quoted string
    constants. Their contents are then taken verbatim, so a single quote or
    backslash (e.g. a dependent variable such as  color = 'red') cannot break
    the generated SQL.
    """
    def _dollar_quote(value):
        # PostgreSQL dollar-quoted constant: no escape processing inside.
        return "$__madlib_rf__${0}$__madlib_rf__$".format(value)

    cat_features = kwargs['cat_features']
    con_features = kwargs['con_features']

    if kwargs['dep_list']:
        kwargs['dep_list_str'] = (
            "$dep_list$" +
            ','.join('"{0}"'.format(str(dep)) for dep in kwargs['dep_list']) +
            "$dep_list$")
    else:
        kwargs['dep_list_str'] = "NULL"

    indep_type = ', '.join(_dict_get_quoted(kwargs['all_cols_types'], col)
                           for col in cat_features + con_features)
    dep_type = _get_dep_type(kwargs['training_table_name'],
                             kwargs['dependent_variable'])

    kwargs['source_table_lit'] = _dollar_quote(kwargs['training_table_name'])
    kwargs['model_table_lit'] = _dollar_quote(kwargs['output_table_name'])
    kwargs['id_col_name_lit'] = _dollar_quote(kwargs['id_col_name'])
    kwargs['dependent_variable_lit'] = _dollar_quote(kwargs['dependent_variable'])
    kwargs['features_lit'] = _dollar_quote(','.join(cat_features + con_features))
    kwargs['cat_features_lit'] = _dollar_quote(','.join(cat_features))
    kwargs['con_features_lit'] = _dollar_quote(','.join(con_features))
    kwargs['dep_type_lit'] = _dollar_quote(dep_type)
    kwargs['indep_type_lit'] = _dollar_quote(indep_type)
    if kwargs['grouping_cols']:
        kwargs['grouping_cols_lit'] = _dollar_quote(kwargs['grouping_cols'])
    else:
        kwargs['grouping_cols_lit'] = 'NULL'
    kwargs['n_rows_skipped'] = kwargs['n_all_rows'] - kwargs['n_rows']

    kwargs['output_table_summary'] = add_postfix(kwargs['output_table_name'], "_summary")
    sql_create_summary_table = """
        CREATE TABLE {output_table_summary} AS
        SELECT
            'forest_train'::text              AS method,
            '{is_classification}'::boolean    AS is_classification,
            {source_table_lit}::text          AS source_table,
            {model_table_lit}::text           AS model_table,
            {id_col_name_lit}::text           AS id_col_name,
            {dependent_variable_lit}::text    AS dependent_varname,
            {features_lit}::text              AS independent_varnames,
            {cat_features_lit}::text          AS cat_features,
            {con_features_lit}::text          AS con_features,
            {grouping_cols_lit}::text         AS grouping_cols,
            {num_trees}::integer              AS num_trees,
            {num_random_features}::integer    AS num_random_features,
            {max_tree_depth}::integer         AS max_tree_depth,
            {min_split}::integer              AS min_split,
            {min_bucket}::integer             AS min_bucket,
            {num_bins}::integer               AS num_splits,
            {verbose}::boolean                AS verbose,
            {importance}::boolean             AS importance,
            {num_permutations}::integer       AS num_permutations,
            {num_groups}::integer             AS num_all_groups,
            {num_failed_groups}::integer      AS num_failed_groups,
            {n_rows}::integer                 AS total_rows_processed,
            {n_rows_skipped}::integer         AS total_rows_skipped,
            {dep_list_str}::text              AS dependent_var_levels,
            {dep_type_lit}::text              AS dependent_var_type,
            {indep_type_lit}::text            AS independent_var_types
        """.format(**kwargs)
    plpy.notice("sql_create_summary_table:\n" + sql_create_summary_table)
    plpy.execute(sql_create_summary_table)
| # ------------------------------------------------------------ |
| |
| |
def _create_group_table(
        schema_madlib, output_table_name, oob_error_table,
        importance_table, cat_features_info_table, grp_key_to_grp_cols,
        grouping_cols, tree_terminated):
    """Create the group table for random forest.

    Builds '<output_table_name>_group' with one row per gid. Each row holds:
    the grouping columns (if any), a `success` flag, categorical level
    metadata, the OOB error, and variable importance. Importance comes from a
    LEFT OUTER JOIN, so it is NULL when absent.

    Args:
        tree_terminated: dict mapping grp_key -> status. A status equal to 1
            is stored as success = TRUE; any other value as FALSE.
    """
    grouping_cols_str = ('' if grouping_cols is None
                         else grouping_cols + ",")
    group_table_name = add_postfix(output_table_name, "_group")
    sql_create_group_table = """
        CREATE TABLE {group_table_name} AS
        SELECT
            gid,
            {grouping_cols_str}
            grp_finished as success,
            cat_n_levels,
            cat_levels_in_text,
            oob_error,
            cat_var_importance,
            con_var_importance
        FROM
            {oob_error_table}
        JOIN
            {grp_key_to_grp_cols}
        USING (gid)
        JOIN (
            SELECT
                unnest($1) as grp_key,
                unnest($2) as grp_finished
        ) tree_terminated
        USING (grp_key)
        JOIN
            {cat_features_info_table}
        USING (gid)
        LEFT OUTER JOIN
            {importance_table}
        USING (gid)
        """.format(**locals())
    plpy.notice("sql_create_group_table:\n" + sql_create_group_table)
    plan_create_group_table = plpy.prepare(sql_create_group_table,
                                           ['text[]', 'boolean[]'])

    # Walk the dict once so each key stays paired with its own flag. Build
    # real lists: dict views are not sequences in Python 3, so PL/Python
    # cannot convert them to SQL arrays.
    grp_keys = []
    grp_finished = []
    for grp_key, status in tree_terminated.items():
        grp_keys.append(grp_key)
        grp_finished.append(bool(status == 1))
    plpy.execute(plan_create_group_table, [grp_keys, grp_finished])
| # ------------------------------------------------------------------------- |
| |
| |
def _create_empty_result_table(schema_madlib, output_table_name):
    """Create the initially-empty model table that receives every tree.

    Columns: gid integer, sample_id integer, tree <schema_madlib>.bytea8.
    """
    ddl = """
        CREATE TABLE {output_table_name} (
            gid integer,
            sample_id integer,
            tree {schema_madlib}.bytea8);
        """.format(output_table_name=output_table_name,
                   schema_madlib=schema_madlib)
    plpy.notice("sql_create_empty_result_table:\n" + ddl)
    plpy.execute(ddl)
| # ------------------------------------------------------------ |
| |
| |
def _insert_into_result_table(schema_madlib, tree_states, output_table_name,
                              grp_key_to_grp_cols, sample_id):
    """Append one sampled tree per group to the model table.

    `tree_states` is a sequence of mappings, each with 'grp_key' and
    'tree_state' entries. Group keys are mapped to gid through
    `grp_key_to_grp_cols`, and every inserted row is tagged with `sample_id`.
    """
    insert_sql = """
        INSERT INTO {output_table_name}
        SELECT
            gid,
            {sample_id} AS sample_id,
            tree
        FROM
            (
                SELECT
                    unnest($1) AS grp_key,
                    unnest($2) AS tree
            ) grp_key_to_tree
        JOIN
            {grp_key_to_grp_cols}
        USING (grp_key)
        """.format(output_table_name=output_table_name,
                   sample_id=sample_id,
                   grp_key_to_grp_cols=grp_key_to_grp_cols)
    insert_plan = plpy.prepare(
        insert_sql, ['text[]', '{0}.bytea8[]'.format(schema_madlib)])
    grp_keys = [state['grp_key'] for state in tree_states]
    trees = [state['tree_state'] for state in tree_states]
    plpy.execute(insert_plan, [grp_keys, trees])
| # ------------------------------------------------------------ |
| |
| |
def _forest_validate_args(
        training_table_name, output_table_name, id_col_name,
        list_of_features, dependent_variable, list_of_features_to_exclude,
        grouping_cols, num_trees, num_random_features, n_perm,
        max_tree_depth, min_split, min_bucket, num_bins, sample_ratio):
    """Validate the arguments for the random forest training function.

    Checks run in a fixed order and each failure is reported through _assert
    with a 'Random forest error: ...' message.
    list_of_features_to_exclude is accepted but not checked here.
    """
    module_name = 'Random forest'

    def _has_text(value):
        # True for a non-None string with non-whitespace content.
        return value is not None and value.strip() != ''

    input_tbl_valid(training_table_name, module_name)
    cols_in_tbl_valid(training_table_name, [id_col_name], module_name)

    for out_tbl in (output_table_name,
                    add_postfix(output_table_name, "_group"),
                    add_postfix(output_table_name, "_summary")):
        output_tbl_valid(out_tbl, module_name)

    _assert(_has_text(list_of_features),
            "Random forest error: Features to include is empty.")
    if list_of_features.strip() != '*':
        _assert(is_var_valid(training_table_name, list_of_features),
                "Random forest error: Invalid feature list ({0})".
                format(list_of_features))
    _assert(_has_text(dependent_variable),
            "Random forest error: Dependent variable is empty.")
    _assert(is_var_valid(training_table_name, dependent_variable),
            "Random forest error: Invalid dependent variable ({0}).".
            format(dependent_variable))

    if _has_text(grouping_cols):
        _assert(is_var_valid(training_table_name, grouping_cols),
                "Random forest error: Invalid grouping column argument.")

    _assert(num_trees > 0, "Random forest error: num_trees must be positive.")
    _assert(n_perm > 0, "Random forest error: num_permutations must be positive.")
    if num_random_features is not None:
        _assert(num_random_features > 0,
                "Random forest error: num_random_features must be positive.")
    _assert(0 <= max_tree_depth <= 15,
            "Random forest error: max_tree_depth must be non-negative and less than 16.")
    _assert(min_split > 0, "Random forest error: min_split must be positive.")
    _assert(min_bucket > 0, "Random forest error: min_bucket must be positive.")
    _assert(num_bins > 1, "Random forest error: number of bins must be at least 2.")
    _assert(0 < sample_ratio <= 1,
            "Random forest error: sample_ratio must be in (0, 1].")
| # ------------------------------------------------------------ |
| |
| |
def _validate_predict(model, source, output, pred_type):
    """Validate forest_predict arguments.

    The model table and its '_group' and '_summary' companions must exist
    and contain the columns prediction relies on. `source` must exist,
    `output` must be a valid new table, and `pred_type` must be 'response'
    or 'prob'.
    """
    module_name = 'Random forest'
    required_columns = (
        (model, ['gid', 'sample_id', 'tree']),
        (add_postfix(model, "_group"),
         ['gid', 'cat_n_levels', 'cat_levels_in_text']),
        (add_postfix(model, "_summary"),
         ["grouping_cols", "id_col_name", "dependent_varname",
          "cat_features", "con_features", "is_classification"]),
    )
    for tbl, cols in required_columns:
        input_tbl_valid(tbl, module_name)
        cols_in_tbl_valid(tbl, cols, module_name)

    input_tbl_valid(source, module_name)
    output_tbl_valid(output, module_name)

    _assert(pred_type in ('response', 'prob'),
            "Random forest error: pred_type should be 'response' or 'prob'")
| # ------------------------------------------------------------ |
| |
| |
def _validate_get_tree(model, gid, sample_id):
    """Validate get_tree arguments.

    Only checks that `model` exists and has the gid, sample_id and tree
    columns. The gid and sample_id values themselves are not checked here.
    """
    module_name = 'Random forest'
    input_tbl_valid(model, module_name)
    cols_in_tbl_valid(model, ['gid', 'sample_id', 'tree'], module_name)
| |
| # ------------------------------------------------------------ |
| |
def forest_predict_help_message(schema_madlib, message, **kwargs):
    """Return the help text for forest_predict.

    Args:
        schema_madlib: schema name substituted into the text.
        message: None/empty for the summary; 'usage', 'help' or '?' for the
            usage text; 'example' or 'examples' for an example. Matching is
            case-insensitive and ignores surrounding whitespace. Any other
            value yields a short "No such option" hint.

    Returns:
        The help string with the schema name filled in.
    """
    if not message:
        help_string = """
------------------------------------------------------------
                        SUMMARY
------------------------------------------------------------
Functionality: Random Forest Prediction

Random forests use a forest-based predictive model to predict
the value of a target variable based on several input variables.
This is the function to make predictions using the model trained
by the function 'forest_train'.

For more details on the function usage:
    SELECT {schema_madlib}.forest_predict('usage');
For an example on using this function:
    SELECT {schema_madlib}.forest_predict('example');
        """
    elif message.lower().strip() in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------
                        USAGE
------------------------------------------------------------
SELECT {schema_madlib}.forest_predict(
    'forest_model',     -- Model table name (output of forest_train)
    'new_data_table',   -- Source data table
    'output_table',     -- The name of the table storing the predictions
    'type'              -- Type of prediction output, 'response' or 'prob'
);

Note: The 'new_data_table' should have the same 'id_col_name' column as used
in the training function. This is used to correlate the prediction data row with
the actual prediction in the output table.

------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the '<id_col_name>' column giving
the 'id' for each prediction and the prediction columns for the response
variable (also called the dependent variable).

If prediction type = 'response', then the table has a single column with the
prediction value of the response. The type of this column depends on the type
of the response variable used during training. The response value for regression
is the average prediction across all trees, and is the majority vote in
the case of classification.

If prediction type = 'prob', then the table has multiple columns, one for each
possible value of the response variable. The columns are labeled as
'estimated_prob_<dep value>', where <dep value> is one of the possible values
of the response. This is only available for classification models, and the
value is the fraction of votes in each category.

        """
    elif message.lower().strip() in ['example', 'examples']:
        help_string = """
------------------------------------------------------------
                        EXAMPLE
------------------------------------------------------------
-- Assuming the example of forest_train has been run
SELECT {schema_madlib}.forest_predict(
    'forest_out',
    'dummy_dt_src',
    'forest_predict_out',
    'response'
);
        """
    else:
        help_string = "No such option. Use {schema_madlib}.forest_predict('usage')"

    return help_string.format(schema_madlib=schema_madlib)