| # coding=utf-8 |
| """ |
| @file random_forest.py_in |
| |
| @brief Random Forest: Driver functions |
| |
| @namespace random_forest |
| """ |
| |
| import plpy |
| from math import sqrt, ceil |
| |
| from internal.db_utils import quote_literal |
| |
| from utilities.control import MinWarning |
| from utilities.control import OptimizerControl |
| from utilities.control import HashaggControl |
| from utilities.utilities import _assert |
| from utilities.utilities import add_postfix |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import is_psql_boolean_type |
| from utilities.utilities import py_list_to_sql_string |
| from utilities.utilities import split_quoted_delimited_str |
| from utilities.utilities import unique_string |
| |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import get_cols_and_types |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import is_var_valid |
| from utilities.validate_args import output_tbl_valid |
| from utilities.validate_args import table_exists |
| |
| from decision_tree import _tree_train_using_bins |
| from decision_tree import _tree_train_grps_using_bins |
| from decision_tree import _get_bins |
| from decision_tree import _get_bins_grps |
| from decision_tree import _get_features_to_use |
| from decision_tree import _is_dep_categorical |
| from decision_tree import _get_n_and_deplist |
| from decision_tree import _classify_features |
| from decision_tree import _get_filter_str |
| from decision_tree import _get_display_header |
| from decision_tree import get_feature_str |
| from decision_tree import get_grouping_array_str |
| from decision_tree import _compute_var_importance |
| from decision_tree import _create_cat_features_info_table |
| # ------------------------------------------------------------ |
| |
| |
def forest_train_help_message(schema_madlib, message, **kwargs):
    """ Help message for Random Forest

    Args:
        @param schema_madlib: str, MADlib schema name, used to qualify the
                              function names shown in the returned text
        @param message: str, help topic. None/empty returns a short summary;
                        'usage', 'help' or '?' (case/space insensitive)
                        returns the full usage text; anything else returns a
                        pointer to the 'usage' topic.

    Returns:
        str. The requested help text with {schema_madlib} substituted.
    """
    if not message:
        help_string = """
------------------------------------------------------------
SUMMARY
------------------------------------------------------------
Functionality: Random Forest

Random forests use a forest-based predictive model to
predict the value of a target variable based on several input variables.

For more details on the function usage:
SELECT {schema_madlib}.forest_train('usage');
"""
    elif message.lower().strip() in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------
USAGE
------------------------------------------------------------
SELECT {schema_madlib}.forest_train(
    'training_table',       -- Data table name
    'output_table',         -- Table name to store the forest model
    'id_col_name',          -- Row ID, used in forest_predict
    'dependent_variable',   -- The column to fit
    'list_of_features',     -- Comma separated column names to be
                                used as the predictors, can be '*'
                                to include all columns except the
                                dependent_variable
    'features_to_exclude',  -- Comma separated column names to be
                                excluded if list_of_features is '*'
    'grouping_cols',        -- Comma separated column names used to
                                group the data. A random forest model
                                will be created for each group. Default
                                is NULL
    num_trees,              -- Integer, default: 100. Maximum number of trees
                                to grow in the Random Forest model. Actual
                                number of trees grown may be slightly different.
    num_random_features,    -- Integer, default: sqrt(n) if classification tree,
                                otherwise n/3. Number of features to randomly
                                select at each split.
    importance,             -- Boolean, whether to calculate variable importance,
                                default is True
    num_permutations        -- Number of times to permute each feature value while
                                calculating variable importance, default is 1.
    max_tree_depth,         -- Maximum depth of any node, default is 10
    min_split,              -- Minimum number of observations that must
                                exist in a node for a split to be
                                attempted, default is 20
    min_bucket,             -- Minimum number of observations in any
                                terminal node, default is min_split/3
    num_splits,             -- Number of bins to find possible node
                                split threshold values for continuous
                                variables, default is 100 (Must be greater than 1)
    'surrogate_params',     -- Text, Comma-separated string of key-value pairs
                                controlling the behavior of surrogate splits for
                                each node in a tree. 'max_surrogates=n', where n
                                is a positive integer with the default 0.
    verbose,                -- Boolean, whether to print more info,
                                default is False
    sample_ratio            -- Double precision, in the range of (0, 1], default: 1
                                If sample_ratio is less than 1, a bootstrap sample
                                size smaller than the data table is expected to be
                                used for training each tree in the forest.
);

------------------------------------------------------------
OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the following columns (
quoted items are of text type.):
    gid                 -- Integer. group id that uniquely identifies
                            a set of grouping column values.
    sample_id           -- Integer. id of the bootstrap sample that
                            this tree is a part of.
    tree                -- bytea8. trained tree model stored in binary
                            format.

The output summary table ('output_table_summary') has the following
columns:
    'method'                -- Method name: 'forest_train'
    is_classification       -- boolean. True if it is a classification model.
    'source_table'          -- Data table name
    'model_table'           -- Tree model table name
    'id_col_name'           -- The ID column name
    'dependent_varname'     -- Response variable column name
    'independent_varnames'  -- Comma-separated feature column names
    'cat_features'          -- Comma-separated column names of categorical variables
    'con_features'          -- Comma-separated column names of continuous variables
    'grouping_cols'         -- Grouping column names
    num_trees               -- Number of trees grown by the model
    num_random_features     -- Number of features randomly selected for each split.
    max_tree_depth          -- Maximum depth of any node.
    min_split               -- Minimum number of observations that must
                                exist in a node for a split to be
                                attempted.
    min_bucket              -- Minimum number of observations in any
                                terminal node.
    num_splits              -- Number of bins to find possible node
                                split threshold values for continuous
                                variables.
    verbose                 -- Boolean, whether to print more info.
    importance              -- Boolean, whether to calculate variable importance,
    num_permutations        -- Number of times to permute each feature value while
                                calculating variable importance
    num_all_groups          -- Int. Number of groups during forest training.
    num_failed_groups       -- Number of groups for which training failed
    total_rows_processed    -- Number of rows used in the model training
    total_rows_skipped      -- Number of rows skipped because of NULL values
    dependent_var_levels    -- For classification, the distinct levels of
                                the dependent variable
    dependent_var_type      -- The type of dependent variable

A third table contains the information about the grouping ('output_table_group') and
it has the following columns:
    gid                 -- Integer. group id that uniquely identifies a set
                            of grouping column values.
    <...>               -- Grouping columns, if provided in input. Same type as
                            in the training data table. This could be multiple
                            columns depending on the grouping_cols input.
    success             -- Boolean, indicator of the success of the group.
    cat_levels_in_text  -- text[]. Ordered levels of categorical variables.
    cat_n_levels        -- integer[]. Number of levels for each categorical variable.
    oob_error           -- double precision. Out-of-bag error for the random forest model.
    cat_var_importance  -- double precision[]. Variable importance for categorical
                            features. The order corresponds to the order of the
                            variables as found in cat_features in <output_table>_summary.
    con_var_importance  -- double precision[]. Variable importance for continuous
                            features. The order corresponds to the order of the
                            variables as found in con_features in <model_table>_summary.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.forest_train('usage')"
    return help_string.format(schema_madlib=schema_madlib)
| # ------------------------------------------------------------ |
| |
| |
def forest_train(
        schema_madlib, training_table_name, output_table_name, id_col_name,
        dependent_variable, list_of_features, list_of_features_to_exclude,
        grouping_cols, num_trees, num_random_features,
        importance, num_permutations, max_tree_depth,
        min_split, min_bucket, num_bins,
        null_handling_params, verbose=False, sample_ratio=None, **kwargs):
    """ Random forest main training function

    Args:
        @param schema_madlib: str, MADlib schema name
        @param training_table_name: str, source table name
        @param output_table_name: str, model table name
        @param id_col_name: str, id column name to uniquely identify each row
        @param dependent_variable: str, dependent variable column name
        @param list_of_features: str, Comma-separated list of feature column names,
                                    can also be '*' implying all columns
                                    except dependent_variable
        @param list_of_features_to_exclude: str, features to exclude if '*' is used
        @param grouping_cols: str, List of grouping columns to group the data
        @param num_trees: int, Number of trees in the forest
        @param num_random_features: int, Number of random features used in spliting nodes
        @param importance: boolean, Whether or not to calculate variable importance
        @param num_permutations: int, Number of times to permute each feature value
                                    during calculation of variable importance
        @param max_tree_depth: int, Maximum depth of each tree
        @param min_split: int, Minimum tuples in a node before splitting it
        @param min_bucket: int, Minimum tuples in each child before splitting a node
        @param num_bins: int, Number of bins for quantizing a continuous variables
        @param null_handling_params: str, comma-separated key-value string with
                                    'max_surrogates' (int, default 0) and
                                    'null_as_category' (bool, default False)
        @param verbose: str, Verbosity of output messages
        @param sample_ratio: float, subsampling ratio for generating src_view

    Returns:
        None. Side effect: creates the model table, its '_group' table and
        its '_summary' table (see forest_train_help_message for the schema),
        then drops all intermediate working tables.
    """
    msg_level = "notice" if verbose else "warning"
    with MinWarning(msg_level):
        with OptimizerControl(False):
            # we disable optimizer (ORCA) for platforms that use it
            # since ORCA doesn't provide an easy way to disable hashagg
            with HashaggControl(False):
                # we disable hashagg since large number of groups could
                # result in excessive memory usage.
                # set default values
                if grouping_cols is not None and grouping_cols.strip() == '':
                    grouping_cols = None
                num_trees = 100 if num_trees is None else num_trees
                max_tree_depth = 10 if max_tree_depth is None else max_tree_depth
                # min_split/min_bucket default off each other:
                # both absent -> min_split=20, min_bucket=20//3;
                # only one given -> derive the other (bucket = split//3,
                # split = bucket*3).
                min_split = 20 if min_split is None and min_bucket is None else min_split
                min_bucket = min_split // 3 if not min_bucket else min_bucket
                min_split = min_bucket * 3 if not min_split else min_split
                num_bins = 100 if num_bins is None else num_bins
                sample_ratio = 1 if sample_ratio is None else sample_ratio

                # parse the null-handling key-value string into its two
                # independent settings
                null_handling_dict = extract_keyvalue_params(
                    null_handling_params,
                    dict(max_surrogates=int, null_as_category=bool),
                    dict(max_surrogates=0, null_as_category=False))
                max_n_surr = null_handling_dict['max_surrogates']
                null_as_category = null_handling_dict['null_as_category']
                null_proxy = "__NULL__" if null_as_category else None
                if null_as_category:
                    # can't have two ways of handling tuples with NULL values
                    max_n_surr = 0
                _assert(max_n_surr >= 0,
                        "Maximum number of surrogates ({0}) should be non-negative".
                        format(max_n_surr))
                ##################################################################
                # validate arguments
                _forest_validate_args(training_table_name, output_table_name, id_col_name,
                                      list_of_features, dependent_variable,
                                      list_of_features_to_exclude, grouping_cols,
                                      num_trees, num_random_features,
                                      num_permutations, max_tree_depth,
                                      min_split, min_bucket, num_bins, sample_ratio)

                ##################################################################
                # preprocess arguments
                # expand "*" syntax and exclude some features
                features = _get_features_to_use(schema_madlib,
                                                training_table_name,
                                                list_of_features,
                                                list_of_features_to_exclude,
                                                id_col_name,
                                                '1', dependent_variable,
                                                grouping_cols)

                _assert(bool(features),
                        "Random forest error: No feature is selected for the model.")

                is_classification, dep_is_bool = _is_dep_categorical(
                    training_table_name, dependent_variable)
                split_criterion = 'gini' if is_classification else 'mse'

                # default number of features per split: sqrt(n) for
                # classification, ceil(n/3) for regression
                if num_random_features is None:
                    n_all_features = len(features)
                    num_random_features = int(sqrt(n_all_features) if is_classification
                                              else ceil(float(n_all_features) / 3))

                _assert(0 < num_random_features <= len(features),
                        "Random forest error: Number of features to be selected "
                        "is more than the actual number of features.")

                all_cols_types = dict([(f, get_expr_type(f, training_table_name))
                                       for f in features])
                cat_features, ordered_cat_features, boolean_cats, con_features = \
                    _classify_features(all_cols_types, features)

                # WHERE clause filtering out rows with NULL dependent variable
                # (and NULL grouping values, per _get_filter_str)
                filter_null = _get_filter_str(dependent_variable, grouping_cols)
                # the total number of records
                n_all_rows = plpy.execute("SELECT count(*) FROM {0}".
                                          format(training_table_name))[0]['count']

                if is_classification:
                    # For classifications, we also need to map dependent_variable to integers
                    n_rows, dep_list = _get_n_and_deplist(training_table_name,
                                                          dependent_variable,
                                                          filter_null)
                    dep_n_levels = len(dep_list)
                    _assert(n_rows > 0,
                            "Random forest error: There should be at least one "
                            "data point for each class where all features are non NULL")
                    if dep_is_bool:
                        # false = 0, true = 1
                        # This order is maintained in dep_list since
                        # _get_n_and_deplist returns a sorted list
                        dep = ("(CASE WHEN {0} THEN 1 ELSE 0 END)".
                               format(dependent_variable))
                    else:
                        # map each distinct level to its (sorted) index
                        dep = ("(CASE " +
                               "\n\t\t".join(["WHEN ({0})::text = $${1}$$ THEN {2}".
                                              format(dependent_variable, c, i)
                                              for i, c in enumerate(dep_list)]) +
                               "\nEND)")
                else:
                    # regression: use the dependent variable as-is
                    n_rows = plpy.execute(
                        "SELECT count(*) FROM {0} WHERE {1}".
                        format(training_table_name, filter_null))[0]['count']
                    dep = dependent_variable
                    dep_n_levels = 1
                    dep_list = None

                # a table that maps gid/grp_key to actual columns
                grp_key_to_grp_cols = unique_string()
                # create the above table and perform binning
                if grouping_cols is None:
                    # single pseudo-group: empty grp_key, gid = 1
                    sql_grp_key_to_grp_cols = """
                            CREATE TABLE {grp_key_to_grp_cols} AS
                            SELECT ''::text AS grp_key, 1 AS gid
                            """.format(**locals())
                    plpy.notice("sql_grp_key_to_grp_cols:\n" + sql_grp_key_to_grp_cols)
                    plpy.execute(sql_grp_key_to_grp_cols)

                    # find the bins, one dict containing two arrays: categorical
                    # bins, and continuous bins
                    num_groups = 1
                    bins = _get_bins(schema_madlib, training_table_name,
                                     cat_features, ordered_cat_features,
                                     con_features, num_bins, dep,
                                     boolean_cats, n_rows, is_classification,
                                     dep_n_levels, filter_null, null_proxy)
                    # some features may be dropped because they have only one value
                    cat_features = bins['cat_features']
                    bins['grp_key_cat'] = ['']
                else:
                    grouping_array_str = get_grouping_array_str(
                        training_table_name, grouping_cols)
                    # NOTE(review): grouping_cols_str is not referenced in this
                    # function's SQL; presumably consumed downstream via
                    # _create_summary_table(**locals()) — verify before removing.
                    grouping_cols_str = ('' if grouping_cols is None
                                         else grouping_cols + ",")
                    sql_grp_key_to_grp_cols = """
                            CREATE TABLE {grp_key_to_grp_cols} AS
                            SELECT
                                {grouping_cols},
                                {grouping_array_str} AS grp_key,
                                (row_number() over ())::integer AS gid
                            FROM {training_table_name}
                            GROUP BY {grouping_cols}
                            """.format(**locals())
                    plpy.notice("sql_grp_key_to_grp_cols:\n" + sql_grp_key_to_grp_cols)
                    plpy.execute(sql_grp_key_to_grp_cols)

                    # find bins
                    num_groups = plpy.execute("""
                            SELECT count(*) FROM {grp_key_to_grp_cols}
                            """.format(**locals()))[0]['count']
                    plpy.notice("Analyzing data to compute split boundaries for variables")
                    bins = _get_bins_grps(schema_madlib, training_table_name,
                                          cat_features, ordered_cat_features,
                                          con_features, num_bins, dep,
                                          boolean_cats, grouping_cols,
                                          grouping_array_str, n_rows,
                                          is_classification, dep_n_levels,
                                          filter_null, null_proxy)
                    cat_features = bins['cat_features']

                # a table for getting information of cat features for each group
                cat_features_info_table = unique_string()
                _create_cat_features_info_table(cat_features_info_table, bins)

                con_splits_table = unique_string()
                _create_con_splits_table(schema_madlib, con_splits_table,
                                         grouping_cols, grp_key_to_grp_cols, bins)

                ##################################################################
                # create views and tables for training (growing) of trees
                # store the prediction for all oob samples
                # for classification, the prediction is of integer type here
                # (LIMIT 0 below: create empty tables with the right column types)
                oob_prediction_table = unique_string()
                sql_create_oob_prediction_table = """
                        CREATE TEMP TABLE {oob_prediction_table} AS
                        SELECT
                            {id_col_name},
                            1 AS sample_id,
                            1 AS gid,
                            {dep} AS dep,
                            {dep} AS oob_prediction,
                            ARRAY[1.0]::float8[] AS cat_permuted_imp_score,
                            ARRAY[1.0]::float8[] AS con_permuted_imp_score
                        FROM {training_table_name}
                        LIMIT 0
                        """.format(**locals())
                plpy.notice("sql_create_oob_prediction_table:\n" + sql_create_oob_prediction_table)
                plpy.execute(sql_create_oob_prediction_table)

                # to store poisson count defining bootstrap sample
                training_pois_cnt_table = unique_string()
                subsample_random_column = unique_string()
                # the bare 1.::double precision is a positional placeholder for
                # the subsample random column filled in by the refresh INSERT
                sql_create_training_pois_cnt = """
                        CREATE TEMP TABLE {training_pois_cnt_table} AS
                        SELECT
                            *,
                            1.::double precision,
                            {schema_madlib}.poisson_random(1) AS poisson_count
                        FROM {training_table_name}
                        LIMIT 0
                        """.format(**locals())
                plpy.notice("sql_create_training_pois_cnt:\n" + sql_create_training_pois_cnt)
                plpy.execute(sql_create_training_pois_cnt)

                # views dependent on current bootstrap sample
                # in-bag rows: poisson count > 0 (row appears that many times)
                src_view = unique_string()
                sql_create_src_view = """
                        CREATE VIEW {src_view} AS
                        SELECT *
                        FROM {training_pois_cnt_table}
                        WHERE poisson_count != 0
                        """.format(**locals())
                plpy.notice("sql_create_src_view:\n" + sql_create_src_view)
                plpy.execute(sql_create_src_view)

                # out-of-bag rows: not drawn into the bootstrap sample
                oob_view = unique_string()
                sql_create_oob_view = """
                        CREATE VIEW {oob_view} AS
                        SELECT *
                        FROM {training_pois_cnt_table}
                        WHERE poisson_count = 0
                        """.format(**locals())
                plpy.notice("sql_create_oob_view:\n" + sql_create_oob_view)
                plpy.execute(sql_create_oob_view)
                if importance:
                    impurity_imp_table = unique_string(desp='temp_out')
                else:
                    impurity_imp_table = ''
                _create_empty_result_table(schema_madlib, output_table_name,
                                           impurity_imp_table, importance)

                ##################################################################
                # training random forest
                # tree_terminated maps grp_key -> 'finished' state of the
                # group's most recent tree
                tree_terminated = dict()
                for sample_id in range(1, num_trees + 1):
                    # sample_ratio == 1: skip the per-row random() call and
                    # keep every row (0 < sample_ratio always holds)
                    if 1 - sample_ratio < 1e-6:
                        random_sample_expr = "0.::double precision"
                    else:
                        random_sample_expr = "random()"

                    # draw a fresh Poisson(1) bootstrap sample for this tree
                    sql_refresh_training_pois_cnt = """
                            TRUNCATE TABLE {training_pois_cnt_table} CASCADE;
                            INSERT INTO {training_pois_cnt_table}
                            SELECT
                                *,
                                {schema_madlib}.poisson_random(1) AS poisson_count
                            FROM
                            (
                                SELECT
                                    *,
                                    {random_sample_expr} AS {subsample_random_column}
                                FROM {training_table_name}
                            ) subq
                            WHERE {subsample_random_column} < {sample_ratio}
                            """.format(**locals())
                    plpy.notice("sql_refresh_training_pois_cnt:\n" + sql_refresh_training_pois_cnt)
                    plpy.execute(sql_refresh_training_pois_cnt)

                    if verbose:
                        # report in-bag/out-of-bag sizes for this sample
                        tup_cnt_in_view = plpy.execute("""
                                SELECT
                                    count(*) AS c,
                                    sum(poisson_count) AS s
                                FROM {src_view}
                                """.format(**locals()))[0]
                        src_cnt = tup_cnt_in_view['c']
                        dup_cnt = tup_cnt_in_view['s']
                        oob_cnt = plpy.execute("""
                                SELECT count(*) AS c FROM {oob_view}
                                """.format(**locals()))[0]['c']
                        plpy.notice("""
                                src_cnt: {src_cnt},
                                oob_cnt: {oob_cnt},
                                dup_cnt: {dup_cnt}.
                                """.format(**locals()))

                    if not grouping_cols:
                        # single-group training: one tree per iteration
                        tree = _tree_train_using_bins(
                            schema_madlib, bins, src_view, cat_features, con_features,
                            boolean_cats, num_bins, 'poisson_count', dep, min_split,
                            min_bucket, max_tree_depth, filter_null, dep_n_levels,
                            is_classification, split_criterion, True,
                            num_random_features, max_n_surr, null_proxy)

                        tree['grp_key'] = ''
                        if importance:
                            tree.update(_compute_var_importance(
                                schema_madlib, tree,
                                len(cat_features), len(con_features)))
                        tree_states = [tree]
                        tree_terminated = {'': tree['finished']}
                    else:
                        # grouped training: one tree per group per iteration
                        tree_states = _tree_train_grps_using_bins(
                            schema_madlib, bins, src_view, cat_features, con_features,
                            boolean_cats, num_bins, 'poisson_count', grouping_cols,
                            grouping_array_str, dep, min_split, min_bucket,
                            max_tree_depth, filter_null, dep_n_levels,
                            is_classification, split_criterion,
                            cat_features_info_table,
                            subsample=True,
                            n_random_features=num_random_features,
                            tree_terminated=tree_terminated,
                            max_n_surr=max_n_surr, null_proxy=null_proxy)

                        # If a tree for a group is terminated (not finished properly),
                        # then we do not need to compute other trees, and can just
                        # stop calculating that group further.
                        for tree in tree_states:
                            grp_key = tree['grp_key']
                            tree_terminated[grp_key] = tree['finished']
                            if importance:
                                importance_vectors = _compute_var_importance(
                                    schema_madlib, tree,
                                    len(cat_features),
                                    len(con_features))
                                tree.update(**importance_vectors)

                    _insert_into_result_table(
                        schema_madlib, tree_states, output_table_name, impurity_imp_table,
                        grp_key_to_grp_cols, sample_id, importance, grouping_cols)

                    # accumulate out-of-bag predictions (and permuted scores
                    # for importance) for this tree's OOB rows
                    _calculate_oob_prediction(
                        schema_madlib, output_table_name, cat_features_info_table,
                        con_splits_table, oob_prediction_table, oob_view,
                        sample_id, id_col_name, cat_features, con_features,
                        training_table_name, grouping_cols, grp_key_to_grp_cols, dep,
                        num_permutations, is_classification, importance,
                        num_bins, filter_null, null_proxy)

                ###################################################################
                # evaluating and summarizing random forest

                oob_error_table = unique_string()
                _calculate_oob_error(schema_madlib, oob_prediction_table,
                                     oob_error_table, id_col_name,
                                     is_classification)
                if importance:
                    importance_table = unique_string()
                    _calculate_oob_variable_importance(
                        schema_madlib, oob_prediction_table, is_classification,
                        importance_table, len(cat_features), len(con_features))
                else:
                    importance_table = ''

                _create_group_table(schema_madlib,
                                    output_table_name,
                                    impurity_imp_table,
                                    oob_error_table,
                                    importance_table,
                                    cat_features_info_table,
                                    grp_key_to_grp_cols,
                                    grouping_cols,
                                    tree_terminated)

                # a group failed if its last tree's 'finished' flag != 1
                num_failed_groups = sum(1 for v in tree_terminated.values() if v != 1)
                # _create_summary_table picks the values it needs from locals()
                _create_summary_table(**locals())

                # drop all intermediate working tables
                sql_cleanup = """
                        DROP TABLE IF EXISTS {training_pois_cnt_table} CASCADE;
                        DROP TABLE IF EXISTS {oob_prediction_table} CASCADE;
                        DROP TABLE IF EXISTS {importance_table} CASCADE;
                        DROP TABLE IF EXISTS {impurity_imp_table} CASCADE;
                        DROP TABLE IF EXISTS {oob_error_table} CASCADE;
                        DROP TABLE IF EXISTS {cat_features_info_table} CASCADE;
                        DROP TABLE IF EXISTS {con_splits_table} CASCADE;
                        DROP TABLE IF EXISTS {grp_key_to_grp_cols} CASCADE;
                        """.format(**locals())
                plpy.notice("sql_cleanup:\n" + sql_cleanup)
                plpy.execute(sql_cleanup)

    return None
| # ------------------------------------------------------------ |
| |
| |
def forest_predict(schema_madlib, model, source, output,
                   pred_type='response', **kwargs):
    """ Predict with a trained random forest model.

    Aggregates the per-tree predictions from every tree in the forest:
    majority vote (mode) for classification, average for regression.

    Args:
        @param schema_madlib: str, Name of MADlib schema
        @param model: str, Name of table containing the forest model
        @param source: str, Name of table containing prediction data
        @param output: str, Name of table to output the results
        @param pred_type: str, The type of output required:
                        'response' gives the actual response values,
                        'prob' gives the probability of the classes in a
                        classification model.
                        For regression model, only type='response' is defined.

    Returns:
        None

    Side effect:
        Creates an output table containing the prediction for given source table

    Throws:
        None
    """
    pred_type = 'response' if pred_type is None or pred_type == '' else pred_type
    _validate_predict(model, source, output, pred_type)

    model_summary = add_postfix(model, "_summary")
    model_group = add_postfix(model, "_group")
    # obtain the cat_features and con_features from model table
    summary_elements = plpy.execute("SELECT * FROM {0}".format(model_summary))[0]
    cat_features = split_quoted_delimited_str(summary_elements["cat_features"])
    con_features = split_quoted_delimited_str(summary_elements["con_features"])
    id_col_name = summary_elements["id_col_name"]
    grouping_cols = summary_elements.get("grouping_cols")  # optional, default = None
    dep_varname = summary_elements["dependent_varname"]
    dep_levels = split_quoted_delimited_str(summary_elements["dependent_var_levels"])
    is_classification = summary_elements["is_classification"]
    dep_type = summary_elements['dependent_var_type']
    null_proxy = summary_elements.get('null_proxy')  # optional, default = None

    # pred_type='prob' is allowed only for classification
    _assert(is_classification or pred_type == 'response',
            "Random forest error: pred_type cannot be 'prob' for regression model.")

    cat_features_str, con_features_str = get_feature_str(
        schema_madlib, source, cat_features, con_features,
        "cat_levels_in_text", "cat_n_levels", null_proxy)

    # output column name, e.g. "estimated_<dep>" or "prob_<dep>"
    pred_name = ('"prob_{0}"' if pred_type == "prob" else
                 '"estimated_{0}"').format(dep_varname.replace('"', '').strip())

    # without grouping the _group table has a single row, so a cross join
    # (',' with empty USING) is equivalent to joining on the grouping cols
    join_str = "," if grouping_cols is None else "JOIN"
    using_str = "" if grouping_cols is None else "USING (" + grouping_cols + ")"

    if not is_classification:
        majority_pred_expression = "AVG(aggregated_prediction)"
    else:
        # NOTE(review): a map object is passed here; assumes
        # py_list_to_sql_string accepts any iterable — confirm.
        dep_levels_array_str = py_list_to_sql_string(map(quote_literal, dep_levels),
                                                     'TEXT',
                                                     long_format=True)
        # mode() of the integer-coded predictions, mapped back to the level
        # text via 1-based array indexing
        majority_pred_expression = (
            "({0})[{1}.mode(aggregated_prediction + 1)]::TEXT".
            format(dep_levels_array_str, schema_madlib))

    if is_psql_boolean_type(dep_type):
        # some platforms don't have text to boolean cast. We manually check the string.
        majority_pred_cast_str = ("(case {majority_pred_expression} "
                                  "    when 'true' then true "
                                  "    when 'false' then false "
                                  " end)::BOOLEAN AS {pred_name}")
    else:
        majority_pred_cast_str = "({majority_pred_expression})::{dep_type} AS {pred_name}"
    majority_pred_cast_str = majority_pred_cast_str.format(**locals())
    # NOTE(review): num_trees_grown is computed but not referenced in the SQL
    # below — possibly vestigial; verify before removing.
    num_trees_grown = plpy.execute(
        "SELECT count(DISTINCT sample_id) FROM {0}".format(model))[0]['count']

    if pred_type == "response" or not is_classification:
        # one row per id: aggregate the per-tree responses
        sql_prediction = """
            CREATE TABLE {output} AS
            SELECT
                {id_col_name},
                {majority_pred_cast_str}
            FROM (
                SELECT
                    {id_col_name},
                    {schema_madlib}._predict_dt_response(
                        tree,
                        {cat_features_str}::integer[],
                        {con_features_str}::double precision[]) AS aggregated_prediction
                FROM
                    {source}
                {join_str}
                    {model_group}
                {using_str}
                JOIN
                    {model}
                USING (gid)
                ) prediction_agg
            GROUP BY {id_col_name}
            """.format(**locals())
    else:
        # pred_type == 'prob': one probability column per dependent level
        len_dep_levels = len(dep_levels)
        normalized_majority_pred = unique_string()
        score_format = ', \n'.join([
            '{temp}[{j}] as "estimated_prob_{c}"'.
            format(j=i+1, c=c.strip(' "'), temp=normalized_majority_pred)
            for i, c in enumerate(dep_levels)])

        sql_prediction = """
            CREATE TABLE {output} AS
            SELECT
                {id_col_name},
                {score_format}
            FROM
            (
                SELECT
                    {id_col_name},
                    {schema_madlib}.discrete_distribution_agg(
                        prediction::integer,
                        1.,
                        {len_dep_levels}
                    )::double precision[]
                        AS {normalized_majority_pred}
                FROM
                (
                    SELECT
                        {id_col_name},
                        gid,
                        {schema_madlib}._predict_dt_response(
                            tree,
                            {cat_features_str}::integer[],
                            {con_features_str}::double precision[]) as prediction
                    FROM
                        {source}
                    {join_str}
                        {model_group}
                    {using_str}
                    JOIN
                        {model}
                    USING (gid)
                ) class_prediction_subq
                GROUP BY gid, {id_col_name}
            ) subq
            """.format(**locals())

    with MinWarning('warning'):
        with OptimizerControl(False):
            # we disable optimizer (ORCA) for platforms that use it
            # since ORCA doesn't provide an easy way to disable hashagg
            with HashaggControl(False):
                # we disable hashagg since large number of groups could
                # result in excessive memory usage.
                plpy.notice("sql_prediction:\n"+sql_prediction)
                plpy.execute(sql_prediction)

    return None
| |
| |
def get_tree_surr(schema_madlib, model_table, gid, sample_id, **kwargs):
    """Display the surrogate splits of one tree in a random forest model.

    Thin convenience wrapper around get_tree() that forces plain-text
    output (dot_format=False) with surrogate display turned on.
    """
    display_opts = {'dot_format': False, 'disp_surr': True}
    return get_tree(schema_madlib, model_table, gid, sample_id, **display_opts)
| |
| |
def get_tree(schema_madlib, model_table, gid, sample_id,
             dot_format=True, verbose=False, disp_surr=False, **kwargs):
    """Random forest tree display function.

    Args:
        @param schema_madlib: str, MADlib schema name
        @param model_table: str, name of the trained forest model table
        @param gid: int, group id of the tree to display
        @param sample_id: int, bootstrap sample id of the tree to display
        @param dot_format: bool, if True render in GraphViz dot format,
                           otherwise plain text
        @param verbose: bool, passed through to the dot display function
        @param disp_surr: bool, if True display only the surrogate splits
                          (text format only)

    Returns:
        str, the rendered tree, or None if no tree matches (gid, sample_id).
    """
    _validate_get_tree(model_table, gid, sample_id)
    if dot_format:
        disp_surr = False  # surrogates cannot be displayed in dot format
    bytea8 = schema_madlib + '.bytea8'

    model_table_summary = add_postfix(model_table, "_summary")
    model_table_group = add_postfix(model_table, "_group")
    summary = plpy.execute("SELECT * FROM {model_table_summary}".
                           format(model_table_summary=model_table_summary))[0]
    dep_levels = summary["dependent_var_levels"]
    dep_levels = [''] if not dep_levels else split_quoted_delimited_str(dep_levels)
    table_name = summary["source_table"]
    is_regression = not summary["is_classification"]
    cat_features_str = split_quoted_delimited_str(summary['cat_features'])
    con_features_str = split_quoted_delimited_str(summary['con_features'])

    with MinWarning('warning'):
        sql_tree_result = """
            SELECT
                tree,
                cat_levels_in_text,
                cat_n_levels
            FROM
                {model_table}
            JOIN
                {model_table_group}
            USING (gid)
            WHERE sample_id = {sample_id}
            AND gid = {gid}
            """.format(**locals())
        plpy.notice("sql_tree_result:\n"+sql_tree_result)
        tree_result = plpy.execute(sql_tree_result)

    if not tree_result:
        # Bail out here: the original code fell through to tree_result[0]
        # after the warning, raising an IndexError instead of exiting.
        plpy.warning("no tree found by the given gid and sample_id, exiting...")
        return None
    tree = tree_result[0]

    if tree['cat_levels_in_text']:
        cat_levels_in_text = tree['cat_levels_in_text']
        cat_n_levels = tree['cat_n_levels']
    else:
        cat_levels_in_text = []
        cat_n_levels = []

    return_str_list = []
    if not disp_surr:
        return_str_list.append(_get_display_header(table_name, dep_levels,
                                                   is_regression, dot_format))
    else:
        return_str_list.append("""
            -------------------------------------
            Surrogates for internal nodes
            -------------------------------------
            """)

    with MinWarning('warning'):
        if disp_surr:
            # Output only surrogate information for the internal nodes of tree
            sql = """SELECT {0}._display_decision_tree_surrogate(
                    $1, $2, $3, $4, $5) as display_tree
                """.format(schema_madlib)
            # execute sql to get display string
            sql_plan = plpy.prepare(sql, [bytea8,
                                          'text[]', 'text[]', 'text[]',
                                          'int[]'])
            tree_display = plpy.execute(
                sql_plan, [tree['tree'], cat_features_str, con_features_str,
                           cat_levels_in_text, cat_n_levels])[0]
        else:
            # Output the splits in each node of the tree
            if dot_format:
                sql_display = """
                    SELECT {0}._display_decision_tree(
                            $1, $2, $3, $4, $5, $6, '{1}', {2}
                        ) as display_tree
                    """.format(schema_madlib, "", verbose)
            else:
                sql_display = """
                    SELECT {0}._display_text_decision_tree(
                            $1, $2, $3, $4, $5, $6
                        ) as display_tree
                    """.format(schema_madlib)

            plpy.notice("sql_display:\n"+sql_display)
            plan_display = plpy.prepare(sql_display, [bytea8,
                                        'text[]', 'text[]', 'text[]',
                                        'int[]', 'text[]'])
            tree_display = plpy.execute(
                plan_display, [tree['tree'], cat_features_str, con_features_str,
                               cat_levels_in_text, cat_n_levels,
                               dep_levels])[0]

    return_str_list.append(tree_display["display_tree"])
    if dot_format:
        return_str_list.append("} //---end of digraph--------- ")
    return ("\n".join(return_str_list))
| |
| |
def _calculate_oob_prediction(
        schema_madlib, model_table, cat_features_info_table, con_splits_table,
        oob_prediction_table, oob_view, sample_id, id_col_name, cat_features,
        con_features, source_table, grouping_cols, grp_key_to_grp_cols, dep,
        num_permutations, is_classification, importance, num_bins, filter_null, null_proxy=None):
    """Calculate prediction for out-of-bag sample.

    For the tree identified by 'sample_id', inserts one row per out-of-bag
    data point into 'oob_prediction_table': the row id, the true dependent
    value, the tree's prediction, and (when 'importance' is True) the
    permutation-importance scores for categorical and continuous features.

    @param schema_madlib: str, MADlib schema name
    @param model_table: str, table holding one tree per (gid, sample_id)
    @param cat_features_info_table: str, per-gid categorical levels info
    @param con_splits_table: str, per-gid continuous split points
    @param oob_prediction_table: str, target table for the INSERT below
    @param oob_view: str, view selecting the OOB rows for this sample
    @param sample_id: int, id of the bootstrap sample / tree
    @param id_col_name: str, name of the row-id column
    @param grouping_cols: str or None, comma-separated grouping columns
    @param grp_key_to_grp_cols: str, table mapping grp_key -> gid/group cols
    @param dep: str, SQL expression for the dependent variable
    @param importance: bool, whether to compute permutation importance
    @param num_bins: int, number of bins used for continuous features
    @param filter_null: str, WHERE-clause predicate excluding NULL rows
    @param null_proxy: str or None, proxy value used to encode NULLs
    """

    # SQL expressions that encode the raw feature columns into the
    # integer[] (categorical) and float8[] (continuous) arrays expected by
    # the prediction/importance aggregate functions.
    cat_features_str, con_features_str = get_feature_str(
        schema_madlib, source_table, cat_features, con_features,
        "cat_levels_in_text", "cat_n_levels", null_proxy)

    # With no grouping, grp_key_to_grp_cols has a single row and is combined
    # via a plain cross join; otherwise join on the grouping columns.
    join_str = "," if grouping_cols is None else "JOIN"
    using_str = "" if grouping_cols is None else "USING (" + grouping_cols + ")"

    oob_var_dist_view = unique_string()
    if importance:
        # Per-gid distribution of feature values over the OOB rows; used by
        # the importance scoring functions to permute feature values.
        if con_features:
            initialized_float_array = ("{0}.array_of_float({1})::integer[]".
                                       format(schema_madlib, len(con_features)))
        else:
            initialized_float_array = "NULL::integer[]"
        sql_create_oob_var_dist_view = """
            CREATE VIEW {oob_var_dist_view} AS
            SELECT
                gid,
                {schema_madlib}.vectorized_distribution_agg(
                    {schema_madlib}.array_scalar_add(
                        {cat_features_str}::integer[],
                        1 -- -1 shifted to 0 for null values
                    ),
                    {schema_madlib}.array_scalar_add(
                        cat_n_levels::integer[],
                        1 -- -1 shifted to 0 for null values
                    )
                ) AS cat_feature_distributions,
                {schema_madlib}.vectorized_distribution_agg(
                    {schema_madlib}.array_scalar_add(
                        {schema_madlib}._get_bin_indices_by_values(
                            {con_features_str}::double precision[],
                            con_splits
                        ), -- bin_indices, -1 for NaN
                        1 -- -1 shifted to 0 for null values
                    ),
                    {schema_madlib}.array_fill({initialized_float_array},
                                               ({num_bins} + 1)::integer)
                    -- level of any continuous feature == num_bins
                ) AS con_index_distributions
            FROM
                {oob_view}
                {join_str}
                {grp_key_to_grp_cols}
                {using_str}
            JOIN
                {cat_features_info_table}
                USING (gid)
            JOIN
                {con_splits_table}
                USING (gid)
            GROUP BY gid
            """.format(**locals())
    else:
        # Importance disabled: still create the view so the LEFT OUTER JOIN
        # below stays valid; NULL distributions make the scoring functions
        # return NULL.
        sql_create_oob_var_dist_view = """
            CREATE VIEW {oob_var_dist_view} AS
            SELECT
                gid,
                NULL::float8[] AS cat_feature_distributions,
                NULL::float8[] AS con_index_distributions
            FROM {cat_features_info_table}
            """.format(**locals())

    plpy.notice("sql_create_oob_var_dist_view : " + str(sql_create_oob_var_dist_view))
    plpy.execute(sql_create_oob_var_dist_view)

    # Score every OOB row with this sample's tree and compute the permuted
    # importance scores (NULL when importance is disabled).
    sql_oob_predict = """
        INSERT INTO {oob_prediction_table}
        SELECT
            {id_col_name},
            sample_id,
            gid,
            {dep} AS dep,
            {schema_madlib}._predict_dt_response(
                tree,
                {cat_features_str}::integer[],
                {con_features_str}::double precision[]
            ) AS oob_prediction,
            {schema_madlib}._rf_cat_imp_score(
                tree,
                {cat_features_str}::integer[],
                {con_features_str}::double precision[],
                cat_info.cat_n_levels::integer[],
                {num_permutations},
                {dep},
                {is_classification},
                cat_feature_distributions -- if distribution is NULL, returns NULL
            ) AS cat_permuted_imp_score,
            {schema_madlib}._rf_con_imp_score(
                tree,
                {cat_features_str}::integer[],
                {con_features_str}::double precision[],
                con_info.con_splits,
                {num_permutations},
                {dep},
                {is_classification},
                con_index_distributions -- if distribution is NULL, returns NULL
            ) AS con_permuted_imp_score
        FROM
            {oob_view}
            {join_str}
            {grp_key_to_grp_cols}
            {using_str}
        JOIN
        (
            SELECT *
            FROM {model_table}
            WHERE sample_id = {sample_id}
        ) m
        USING (gid)
        JOIN
            {cat_features_info_table} cat_info
            USING (gid)
        JOIN
            {con_splits_table} con_info
            USING (gid)
        LEFT OUTER JOIN -- empty if variable importance is disabled
            {oob_var_dist_view}
            USING (gid)
        WHERE {filter_null}
        """.format(**locals())
    plpy.notice("sql_oob_predict : " + str(sql_oob_predict))
    plpy.execute(sql_oob_predict)
# -------------------------------------------------------------------------
| |
| |
def _create_con_splits_table(schema_madlib, con_splits_table, grouping_cols,
                             grp_key_to_grp_cols, bins):
    """Materialize the continuous-feature split points into a table keyed by gid.

    @param schema_madlib: str, MADlib schema name
    @param con_splits_table: str, name of the table to create
    @param grouping_cols: str or None, comma-separated grouping columns
    @param grp_key_to_grp_cols: str, table mapping grp_key -> gid and group cols
    @param bins: dict; 'con' holds the bytea8 split state(s), and when
                 grouping is used 'grp_key_con' holds the parallel group keys
    """

    bytea8 = schema_madlib + '.bytea8'
    bytea8arr = schema_madlib + '.bytea8[]'
    if grouping_cols is None:
        # Single implicit group: one row with gid = 1.
        # NOTE(review): this branch creates a TEMP table while the grouped
        # branch below creates a regular table -- confirm the asymmetry is
        # intentional.
        sql_create_con_splits_table = """
            CREATE TEMP TABLE {con_splits_table} AS
            SELECT
                1 AS gid,
                $1 AS con_splits
            """.format(con_splits_table=con_splits_table)
        plpy.notice("sql_create_con_splits_table:\n"+sql_create_con_splits_table)
        sql_create_con_splits_plan = plpy.prepare(sql_create_con_splits_table,
                                                  [bytea8])
        plpy.execute(sql_create_con_splits_plan, [bins['con']])
    else:
        # One row per group: unnest the parallel grp_key/con_splits arrays
        # and resolve each grp_key to its gid via grp_key_to_grp_cols.
        sql_create_con_splits_table = """
            CREATE TABLE {con_splits_table} AS
            SELECT
                gid,
                con_splits
            FROM
                {grp_key_to_grp_cols}
            JOIN
            (
                SELECT
                    unnest($1) as grp_key,
                    unnest($2) as con_splits
            ) subq
            USING (grp_key)
            """.format(**locals())
        plpy.notice("sql_create_con_splits_table:\n"+sql_create_con_splits_table)
        sql_create_con_splits_plan = plpy.prepare(sql_create_con_splits_table,
                                                  ['text[]', bytea8arr])
        plpy.execute(sql_create_con_splits_plan,
                     [bins['grp_key_con'], bins['con']])
# ------------------------------------------------------------------------------
| |
| |
def _calculate_oob_variable_importance(
        schema_madlib, oob_prediction_table, is_classification,
        importance_table, n_cat, n_con):
    """Aggregate per-row OOB permutation scores into per-group importance.

    Creates 'importance_table' with columns (gid, oob_cat_var_importance,
    oob_con_var_importance) from the per-row scores previously inserted
    into 'oob_prediction_table' by _calculate_oob_prediction().

    @param schema_madlib: str, MADlib schema name
    @param oob_prediction_table: str, per-row OOB predictions and scores
    @param is_classification: bool, True for classification models
    @param importance_table: str, name of the temp table to create
    @param n_cat, n_con: int, feature counts (not referenced in this body;
                         kept for interface stability)
    """
    sql_create_empty_imp_tbl = """
        CREATE TEMP TABLE {importance_table} (
            gid integer,
            oob_cat_var_importance float8[],
            oob_con_var_importance float8[]
        );
        """.format(**locals())
    plpy.notice("sql_create_empty_imp_tbl:\n" + sql_create_empty_imp_tbl)
    plpy.execute(sql_create_empty_imp_tbl)

    # Per-row score where higher is better: negated squared error for
    # regression, 1/0 correctness indicator for classification.
    # (These literals have no placeholders, so no .format() is needed.)
    if not is_classification:
        # squared error
        score_expression = "-((oob_prediction - dep)^2)"
    else:
        # misclassification
        score_expression = """
            CASE WHEN dep = oob_prediction::integer
                THEN 1.
                ELSE 0.
            END"""

    # Per (sample_id, gid): OOB size, total original score, and the
    # element-wise sums of the permuted score vectors.
    sample_score_view = unique_string()
    sql_create_sample_score_view = """
        CREATE VIEW {sample_score_view} AS
        SELECT
            sample_id,
            gid,
            count(*) as size,
            sum({score_expression}) as score,
            {schema_madlib}.sum(cat_permuted_imp_score::FLOAT8[]) AS cat_permuted_imp_score,
            {schema_madlib}.sum(con_permuted_imp_score::FLOAT8[]) AS con_permuted_imp_score
        FROM
            {oob_prediction_table}
        GROUP BY sample_id, gid
        """.format(**locals())
    plpy.notice("sql_create_sample_score_view:\n" + sql_create_sample_score_view)
    plpy.execute(sql_create_sample_score_view)

    sql_create_importance_table = """
        INSERT INTO {importance_table}
        SELECT
            -- Shift all values if an importance value is negative.
            -- This is performed by subtracting the minimum importance value
            -- (only if the minimum is negative)
            gid,
            {schema_madlib}.array_scalar_add(
                cat_var_imp,
                -{schema_madlib}.array_min(
                    array_append(array_cat(cat_var_imp, con_var_imp),
                                 0.0::double precision))
            ),
            {schema_madlib}.array_scalar_add(
                con_var_imp,
                -{schema_madlib}.array_min(
                    array_append(array_cat(cat_var_imp, con_var_imp),
                                 0.0::double precision))
            )
        FROM (
            -- Compute the average importance over the OOB data where,
            --      importance = original score - permuted score
            -- Since permuted score is a vector and original score is a scalar,
            -- the signs are inverted and then fixed by dividing with
            -- negative of size.
            SELECT
                gid,
                {schema_madlib}.array_avg(
                    {schema_madlib}.array_scalar_mult(
                        {schema_madlib}.array_scalar_add(
                            cat_permuted_imp_score,
                            -score::float8
                        ),
                        (-1. / size)::float8
                    ),
                    FALSE -- not use absolute values
                ) AS cat_var_imp,
                {schema_madlib}.array_avg(
                    {schema_madlib}.array_scalar_mult(
                        {schema_madlib}.array_scalar_add(
                            con_permuted_imp_score,
                            -score::float8
                        ),
                        (-1. / size)::float8
                    ),
                    FALSE -- not use absolute values
                ) AS con_var_imp
            FROM
                {sample_score_view}
            GROUP BY gid
        ) q
        """.format(**locals())
    plpy.notice("sql_create_importance_table:\n" + sql_create_importance_table)
    plpy.execute(sql_create_importance_table)
# -------------------------------------------------------------------------
| |
| |
def _calculate_oob_error(schema_madlib, oob_prediction_table, oob_error_table,
                         id_col_name, is_classification):
    """Calculate out-of-bag error for oob samples.

    Creates 'oob_error_table' (gid, oob_error): the per-tree OOB predictions
    of each data point are first combined into a single forest prediction
    (mean for regression, mode for classification), then the per-point
    residuals are averaged within each group.

    @param schema_madlib: str, MADlib schema name
    @param oob_prediction_table: str, per-row OOB predictions
    @param oob_error_table: str, name of the table to create
    @param id_col_name: str, row-id column used to group per data point
    @param is_classification: bool, True for classification models
    """
    if not is_classification:
        # Regression: squared error against the mean prediction.
        # (No placeholders in this literal, so no .format() is needed.)
        residual_expression = "(dep - forest_prediction)^2"
        forest_prediction_agg = 'avg'
    else:
        # Classification: 0/1 loss against the majority-vote prediction.
        residual_expression = """
            CASE WHEN dep = forest_prediction::integer
                THEN 0.
                ELSE 1.
            END"""
        forest_prediction_agg = "{schema_madlib}.mode".format(**locals())

    sql_compute_oob_error = """
        CREATE TABLE {oob_error_table} AS
        SELECT
            gid,
            avg({residual_expression}) AS oob_error
        FROM
        (
            SELECT
                gid,
                dep,
                {forest_prediction_agg}(oob_prediction) AS forest_prediction
            FROM
                {oob_prediction_table}
            GROUP BY gid, {id_col_name}, dep
        ) prediction_subq
        GROUP BY gid
        """.format(**locals())
    plpy.notice("sql_compute_oob_error : " + str(sql_compute_oob_error))
    plpy.execute(sql_compute_oob_error)
# -------------------------------------------------------------------------
| |
| |
def _create_summary_table(**kwargs):
    """Create the <output_table>_summary table for forest_train.

    Expects 'kwargs' to carry all training parameters and counters
    (cat_features, con_features, dependent_variable, num_trees, n_rows,
    ...). Several derived entries are added to 'kwargs' below before the
    whole dict is splatted into the CREATE TABLE statement.
    """
    kwargs['features'] = ','.join(kwargs['cat_features'] + kwargs['con_features'])
    kwargs['dep_type'] = get_expr_type(kwargs['dependent_variable'],
                                       kwargs['training_table_name'])
    if kwargs['dep_list']:
        if is_psql_boolean_type(kwargs['dep_type']):
            # Special handling for boolean since Python booleans start with
            # capitals (i.e False instead of false)
            # Note: dep_list is sorted, hence 'false' will be first
            kwargs['dep_list_str'] = "'false, true'"
        else:
            # Dollar-quoting avoids escaping issues with levels that
            # contain quotes.
            kwargs['dep_list_str'] = '$__dep_list__${0}$__dep_list__$'.format(
                ','.join(map(str, kwargs['dep_list'])))
    else:
        kwargs['dep_list_str'] = "NULL"
    # Feature types in the same order as cat_features followed by con_features.
    kwargs['indep_type'] = ', '.join(kwargs['all_cols_types'][col]
                                     for col in (kwargs['cat_features'] +
                                                 kwargs['con_features']))
    kwargs['cat_features_str'] = ','.join(kwargs['cat_features'])
    kwargs['con_features_str'] = ','.join(kwargs['con_features'])
    if kwargs['grouping_cols']:
        kwargs['grouping_cols_str'] = "'{grouping_cols}'".format(**kwargs)
    else:
        kwargs['grouping_cols_str'] = 'NULL'
    kwargs['n_rows_skipped'] = kwargs['n_all_rows'] - kwargs['n_rows']

    kwargs['output_table_summary'] = add_postfix(kwargs['output_table_name'], "_summary")
    sql_create_summary_table = """
        CREATE TABLE {output_table_summary} AS
        SELECT
            'forest_train'::text AS method,
            '{is_classification}'::boolean AS is_classification,
            '{training_table_name}'::text AS source_table,
            '{output_table_name}'::text AS model_table,
            '{id_col_name}'::text AS id_col_name,
            '{dependent_variable}'::text AS dependent_varname,
            '{features}'::text AS independent_varnames,
            '{cat_features_str}'::text AS cat_features,
            '{con_features_str}'::text AS con_features,
            {grouping_cols_str}::text AS grouping_cols,
            {num_trees}::integer AS num_trees,
            {num_random_features}::integer AS num_random_features,
            {max_tree_depth}::integer AS max_tree_depth,
            {min_split}::integer AS min_split,
            {min_bucket}::integer AS min_bucket,
            {num_bins}::integer AS num_splits,
            {verbose}::boolean AS verbose,
            {importance}::boolean AS importance,
            {num_permutations}::integer AS num_permutations,
            {num_groups}::integer AS num_all_groups,
            {num_failed_groups}::integer AS num_failed_groups,
            {n_rows}::integer AS total_rows_processed,
            {n_rows_skipped}::integer AS total_rows_skipped,
            {dep_list_str}::text AS dependent_var_levels,
            '{dep_type}'::text AS dependent_var_type,
            '{indep_type}'::text AS independent_var_types,
            '{null_proxy}'::text AS null_proxy
        """.format(**kwargs)
    plpy.notice("sql_create_summary_table:\n" + sql_create_summary_table)
    plpy.execute(sql_create_summary_table)
# ------------------------------------------------------------
| |
| |
def _create_group_table(
        schema_madlib, output_table_name, impurity_imp_table, oob_error_table,
        importance_table, cat_features_info_table, grp_key_to_grp_cols,
        grouping_cols, tree_terminated):
    """ Create the group table for random forest.

    Builds <output_table_name>_group with one row per gid: grouping
    columns (if any), training success flag, categorical level info,
    OOB error and -- when importance was requested -- the OOB and
    (averaged) impurity variable importance arrays.

    @param importance_table: str or falsy; when set, importance columns
           and the corresponding joins are included
    @param tree_terminated: dict mapping grp_key -> 1/0 termination flag
    """
    if importance_table:
        # Average each group's per-tree impurity importance vectors.
        impurity_var_importance_query = """
            SELECT
                gid,
                {schema_madlib}.array_avg(impurity_var_importance, False) AS impurity_var_importance
            FROM {impurity_imp_table}
            GROUP BY gid
            """.format(**locals())
        oob_var_importance_str = (", array_cat(oob_cat_var_importance, "
                                  "oob_con_var_importance) AS oob_var_importance")
        impurity_var_importance_str = ", impurity_var_importance"
        left_join_importance_table_str = ("LEFT OUTER JOIN {0} USING (gid)".
                                          format(importance_table))
        join_impurity_table_str = ("JOIN ({0}) q USING (gid)".
                                   format(impurity_var_importance_query))
    else:
        oob_var_importance_str = ''
        impurity_var_importance_str = ''
        left_join_importance_table_str = ''
        join_impurity_table_str = ''

    grouping_cols_str = ('' if grouping_cols is None else grouping_cols + ",")
    group_table_name = add_postfix(output_table_name, "_group")
    sql_create_group_table = """
        CREATE TABLE {group_table_name} AS
        SELECT
            gid,
            {grouping_cols_str}
            grp_finished AS success,
            cat_n_levels,
            cat_levels_in_text,
            oob_error
            {oob_var_importance_str}
            {impurity_var_importance_str}
        FROM
            {oob_error_table}
        JOIN
            {grp_key_to_grp_cols}
            USING (gid)
        JOIN (
            SELECT
                unnest($1) AS grp_key,
                unnest($2) AS grp_finished
            ) tree_terminated
        USING (grp_key)
        JOIN
            {cat_features_info_table}
            USING (gid)
        {left_join_importance_table_str}
        {join_impurity_table_str}
        """.format(**locals())
    plpy.notice("sql_create_group_table:\n" + sql_create_group_table)
    plan_create_group_table = plpy.prepare(sql_create_group_table,
                                           ['text[]', 'boolean[]'])
    # keys() and values() iterate in a consistent order, so the two arrays
    # stay aligned. Wrap keys() in list() -- on Python 3 a dict view is not
    # accepted where plpy expects an array parameter.
    plpy.execute(plan_create_group_table,
                 [list(tree_terminated.keys()),
                  [v == 1 for v in tree_terminated.values()]])
# -------------------------------------------------------------------------
| |
| |
def _create_empty_result_table(schema_madlib, output_table_name,
                               impurity_imp_table, importance):
    """Create the (initially empty) model table that will hold every tree
    of the forest, plus the impurity-importance scratch table when
    importance scores are requested."""
    create_model_sql = """
        CREATE TABLE {output_table_name} (
            gid integer,
            sample_id integer,
            tree {schema_madlib}.bytea8)
        """.format(output_table_name=output_table_name,
                   schema_madlib=schema_madlib)
    plpy.notice("sql_create_empty_result_table:\n" + create_model_sql)
    plpy.execute(create_model_sql)
    if not importance:
        return
    create_impurity_sql = """
            CREATE TEMP TABLE {impurity_imp_table} (
                gid integer,
                sample_id integer,
                impurity_var_importance double precision[])
            """.format(impurity_imp_table=impurity_imp_table)
    plpy.execute(create_impurity_sql)
# ------------------------------------------------------------
| |
| |
def _insert_into_result_table(schema_madlib, tree_states, output_table_name,
                              impurity_imp_table, grp_key_to_grp_cols,
                              sample_id, importance, grouping_cols):
    """Insert the trees trained for one bootstrap sample into the model table.

    @param tree_states: list of dicts, one per group; each carries
           'grp_key', 'tree_state' and (when importance is requested)
           'impurity_var_importance'
    @param output_table_name: str, model table (gid, sample_id, tree)
    @param impurity_imp_table: str, scratch table for impurity importance
    @param grp_key_to_grp_cols: str, table mapping grp_key -> gid
    @param sample_id: int, id of this bootstrap sample
    @param importance: bool, whether impurity importance rows are inserted
    @param grouping_cols: str or None, used only to decide gid resolution
    """
    if grouping_cols:
        # Resolve each tree's grp_key to its gid through the mapping table.
        grp_join_sql = """JOIN {0} USING (grp_key)""".format(grp_key_to_grp_cols)
        gid_str = "gid"
        grp_key = py_list_to_sql_string([tree_state['grp_key']
                                         for tree_state in tree_states],
                                        'TEXT[]', False)
        grp_key_sql = "unnest({0}) AS grp_key, ".format(grp_key)
    else:
        # Single implicit group: everything maps to gid 1.
        grp_join_sql = ''
        gid_str = "1 AS gid"
        grp_key_sql = ''

    # Trees are passed as a bytea8[] parameter; grp_key order in the inner
    # subquery matches the order of tree_states.
    sql = """
        INSERT INTO {output_table_name}
        SELECT
            {gid_str},
            {sample_id} AS sample_id,
            tree
        FROM (
            SELECT
                {grp_key_sql}
                unnest($1) AS tree
        ) grp_key_to_tree
        {grp_join_sql}
        """.format(**locals())
    sql_plan = plpy.prepare(sql, ['{0}.bytea8[]'.format(schema_madlib)])
    plpy.execute(sql_plan, [[tree_state['tree_state'] for tree_state in tree_states]])

    if importance:
        # Build a VALUES list of (grp_key, importance_vector) pairs; the
        # grp_key is quoted as a literal since this SQL is not a prepared
        # statement.
        grp_imp_values = []
        for tree_state in tree_states:
            importance_vector = py_list_to_sql_string(
                tree_state['impurity_var_importance'],
                'DOUBLE PRECISION',
                True)
            grp_imp_values.append("({0}, {1})".
                                  format(quote_literal(tree_state['grp_key']),
                                         importance_vector))
        sql = """
            INSERT INTO {impurity_imp_table}
            SELECT
                {gid_str},
                {sample_id} AS sample_id,
                impurity_var_importance
            FROM (
                VALUES
                {grp_imp_values_str}
            ) grp_key_to_importance(grp_key, impurity_var_importance)
            {grp_join_sql}
            """.format(grp_imp_values_str=', \n'.join(grp_imp_values),
                       **locals())
        plpy.execute(sql)

# ------------------------------------------------------------------------------
| |
| |
def _forest_validate_args(
        training_table_name, output_table_name, id_col_name,
        list_of_features, dependent_variable, list_of_features_to_exclude,
        grouping_cols, num_trees, num_random_features, n_perm,
        max_tree_depth, min_split, min_bucket, num_bins, sample_ratio):
    """ Validate the arguments for the random forest training function"""

    module = 'Random forest'

    # Input table must exist and contain the id column.
    input_tbl_valid(training_table_name, module)
    cols_in_tbl_valid(training_table_name, [id_col_name], module)

    # None of the output tables may already exist.
    output_tbl_valid(output_table_name, module)
    for suffix in ("_group", "_summary"):
        output_tbl_valid(add_postfix(output_table_name, suffix), module)

    # Feature list: non-empty, and valid expressions unless it is '*'.
    features_empty = (list_of_features is None or
                      not list_of_features.strip().lower())
    _assert(not features_empty,
            "Random forest error: Features to include is empty.")
    if list_of_features.strip() != '*':
        _assert(is_var_valid(training_table_name, list_of_features),
                "Random forest error: Invalid feature list ({0})".
                format(list_of_features))

    # Dependent variable: non-empty and a valid expression.
    dep_empty = (dependent_variable is None or
                 not dependent_variable.strip().lower())
    _assert(not dep_empty,
            "Random forest error: Dependent variable is empty.")
    _assert(is_var_valid(training_table_name, dependent_variable),
            "Random forest error: Invalid dependent variable ({0}).".
            format(dependent_variable))

    if grouping_cols is not None and grouping_cols.strip():
        _assert(is_var_valid(training_table_name, grouping_cols),
                "Random forest error: Invalid grouping column argument.")

    # Numeric parameter ranges.
    _assert(0 < num_trees, "Random forest error: num_trees must be positive.")
    _assert(0 < n_perm, "Random forest error: num_permutations must be positive.")
    if num_random_features is not None:
        _assert(0 < num_random_features,
                "Random forest error: num_random_features must be positive.")
    _assert(0 <= max_tree_depth <= 15,
            "Random forest error: max_tree_depth must be non-negative and less than 16.")
    _assert(0 < min_split, "Random forest error: min_split must be positive.")
    _assert(0 < min_bucket, "Random forest error: min_bucket must be positive.")
    _assert(1 < num_bins, "Random forest error: number of bins must be at least 2.")
    _assert(0 < sample_ratio <= 1,
            "Random forest error: sample_ratio must be in (0, 1].")
# ------------------------------------------------------------
| |
| |
def _validate_predict(model, source, output, pred_type):
    """Validations for input arguments"""
    module = 'Random forest'
    group_table = add_postfix(model, "_group")
    summary_table = add_postfix(model, "_summary")

    # Model table and its companion tables must exist with the expected
    # columns.
    input_tbl_valid(model, module)
    cols_in_tbl_valid(model, ['gid', 'sample_id', 'tree'], module)

    input_tbl_valid(group_table, module)
    cols_in_tbl_valid(group_table,
                      ['gid', 'cat_n_levels', 'cat_levels_in_text'],
                      module)

    input_tbl_valid(summary_table, module)
    cols_in_tbl_valid(summary_table,
                      ["grouping_cols", "id_col_name", "dependent_varname",
                       "cat_features", "con_features", "is_classification"],
                      module)

    input_tbl_valid(source, module)

    output_tbl_valid(output, module)

    _assert(pred_type in ('response', 'prob'),
            "Random forest error: pred_type should be 'response' or 'prob'")
# ------------------------------------------------------------------------------
| |
| |
| # ------------------------------------------------------------------------------ |
| # -- 'get_var_importance' related code below ----------------------------------- |
| # -- This applies to RF and DT ------------------------------------------------- |
| # ------------------------------------------------------------------------------ |
| |
| |
def _validate_get_tree(model, gid, sample_id):
    """Validations for input arguments

    NOTE(review): only the model table and its columns are checked here;
    'gid' and 'sample_id' are accepted but not validated against the
    table contents -- a missing (gid, sample_id) pair is handled by the
    caller.
    """
    input_tbl_valid(model, 'Random forest')
    cols_in_tbl_valid(model, ['gid', 'sample_id', 'tree'], 'Random forest')
# ------------------------------------------------------------------------------
| |
| |
def _validate_var_importance_input(model_table, summary_table, output_table):
    """Check that the model and summary tables exist and that the output
    table does not."""
    existing = ((model_table, "Model table"),
                (summary_table, "Model summary table"))
    for table_name, label in existing:
        _assert(table_exists(table_name),
                "Recursive Partitioning: {0} does not exist.".format(label))
    _assert(not table_exists(output_table),
            "Recursive Partitioning: Output table already exists.")
# ------------------------------------------------------------------------------
| |
| |
def _is_random_forest_model(summary_table):
    """Return True when the summary table records a forest_train model."""
    result = plpy.execute("SELECT method FROM " + summary_table)
    return result[0]['method'] == 'forest_train'
# ------------------------------------------------------------------------------
| |
| |
def _is_impurity_importance_in_model(output_table, summary_table, is_RF=False):
    """ Check if the model has a column named impurity_var_importance
    """
    model_label = "Random forest" if is_RF else "Decision tree"
    _assert(table_exists(output_table),
            "{0}: Model table ({1}) does not exist.".
            format(model_label, output_table))

    if is_RF:
        # importance = True indicates that importance scores were calculated
        importance_flag = plpy.execute(
            "SELECT importance FROM {0}".format(summary_table))[0]['importance']
        _assert(importance_flag,
                "{0}: The model does not have importance information.".
                format(model_label))
    return columns_exist_in_table(output_table, ['impurity_var_importance'])
# ------------------------------------------------------------------------------
| |
| |
@MinWarning("error")
def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
    """ Create table capturing importance scores for each feature.
    For DT, this function will record the impurity importance score if it exists.
    For RF, this function will record the oob variable importance and impurity
    importance (if available) for each variable.

    Args:
        @param schema_madlib: str, MADlib schema name
        @param model_table: str, Model table name
        @param output_table: str, Output table name

    """
    # Validate parameters
    summary_table = add_postfix(model_table, "_summary")
    _validate_var_importance_input(model_table,
                                   summary_table,
                                   output_table)
    is_RF = _is_random_forest_model(summary_table)

    # 'importance_model_table' is the table containing variable importance values.
    # For RF, it is placed in <model_table>_group as opposed to <model_table>
    # for DT.
    importance_model_table = (model_table if not is_RF else
                              add_postfix(model_table, "_group"))
    grouping_cols = plpy.execute("SELECT grouping_cols FROM {summary_table}".
                                 format(**locals()))[0]['grouping_cols']
    if grouping_cols:
        grouping_cols_comma = add_postfix(grouping_cols, ", ")
    else:
        grouping_cols_comma = ''
    is_impurity_imp_col_present = _is_impurity_importance_in_model(
        importance_model_table, summary_table, is_RF=is_RF)

    # convert importance to percentages
    normalization_target = 100.0

    def _unnest_normalize(input_array_str):
        # Scale the array to sum to 'normalization_target' and unnest it
        # so each feature gets its own output row.
        return ("""
            unnest({0}.normalize_sum_array({1}::double precision[],
                                           {2}::double precision))
            """.format(schema_madlib, input_array_str, normalization_target))

    if is_RF:
        if is_impurity_imp_col_present:
            # In versions >= 1.15, the OOB variable importance is captured
            # in a single column: 'oob_var_importance'.
            oob_var_importance_str = (
                "{0} AS oob_var_importance,".
                format(_unnest_normalize('oob_var_importance')))
            impurity_var_importance_str = (
                "{0} AS impurity_var_importance".
                format(_unnest_normalize('impurity_var_importance')))
        else:
            # In versions < 1.15, the OOB variable importance was captured in
            # two different columns: 'cat_var_importance' and 'con_var_importance'
            # Bug fix: the array_cat(...) expression was missing its closing
            # parenthesis, producing unbalanced SQL in this branch.
            oob_var_importance_str = (
                "{0} AS oob_var_importance".
                format(_unnest_normalize(
                    "array_cat(cat_var_importance, con_var_importance)")))
            impurity_var_importance_str = ''
    else:
        # Decision tree models don't have a OOB variable importance
        _assert(is_impurity_imp_col_present,
                "Decision tree: Impurity importance not present in output table")
        oob_var_importance_str = ''
        impurity_var_importance_str = (
            "{0} AS impurity_var_importance".
            format(_unnest_normalize('impurity_var_importance')))

    plpy.execute("""
        CREATE TABLE {output_table} AS
        SELECT {grouping_cols_comma}
               unnest(regexp_split_to_array(independent_varnames, ',')) AS feature,
               {oob_var_importance_str}
               {impurity_var_importance_str}
        FROM {importance_model_table}, {summary_table}
        """.format(**locals()))
# ------------------------------------------------------------------------------
| |
| |
def forest_predict_help_message(schema_madlib, message, **kwargs):
    """Return the help text for forest_predict.

    @param schema_madlib: str, MADlib schema name (substituted into the text)
    @param message: None/'' for the summary, 'usage'/'help'/'?' for usage
                    details; anything else yields a short redirect message
    """
    if not message:
        help_string = """
------------------------------------------------------------
                        SUMMARY
------------------------------------------------------------
Functionality: Random Forest Prediction

Random forests use a forest-based predictive model to predict
the value of a target variable based on several input variables.
This is the function to make predictions using the model trained
by the function 'forest_train'.

For more details on the function usage:
    SELECT {schema_madlib}.forest_predict('usage');
For an example on using this function:
    SELECT {schema_madlib}.forest_predict('example');
        """
    elif message.lower().strip() in ('usage', 'help', '?'):
        # Fix: "corelate" -> "correlate" in the note below.
        help_string = """
------------------------------------------------------------
                        USAGE
------------------------------------------------------------
SELECT {schema_madlib}.forest_predict(
    'forest_model',     -- Model table name (output of forest_train)
    'new_data_table',   -- Source data table
    'output_table',     -- The name of the table storing the predictions
    'type'              -- Type of prediction output, 'response' or 'prob'
);

Note: The 'new_data_table' should have the same 'id_col_name' column as used
in the training function. This is used to correlate the prediction data row with
the actual prediction in the output table.

------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the '<id_col_name>' column giving
the 'id' for each prediction and the prediction columns for the response
variable (also called as dependent variable).

If prediction type = 'response', then the table has a single column with the
prediction value of the response. The type of this column depends on the type
of the response variable used during training. The response value for regression
is the average prediction across all trees, and is the majority vote in
the case of classification.

If prediction type = 'prob', then the table has multiple columns, one for each
possible value of the response variable. The columns are labeled as
'estimated_prob_<dep value>', where <dep value> represents for each value
of the response. This is only for the classification models, and the value
is the fraction of votes in each category.

        """
    else:
        help_string = "No such option. Use {schema_madlib}.forest_predict('usage')"

    return help_string.format(schema_madlib=schema_madlib)
| |
| |
| def _importance_help_message(schema_madlib, message, **kwargs): |
| """ Help message for Decision Tree get_var_importance |
| """ |
| if not message: |
| help_string = """ |
| ------------------------------------------------------------ |
| SUMMARY |
| ------------------------------------------------------------ |
| Functionality: Decision Tree/Random Forest Importance Values Display |
| |
| Create a table to record the importance values for a decision |
| tree (trained using {schema_madlib}.tree_train) or a random |
| forest (trained using {schema_madlib}.forest_train). |
| |
| For more details on the function usage: |
| SELECT {schema_madlib}.get_var_importance('usage'); |
| For an example on using this function: |
| SELECT {schema_madlib}.get_var_importance('example'); |
| """ |
| elif message.lower().strip() in ['usage', 'help', '?']: |
| help_string = """ |
| ------------------------------------------------------------ |
| USAGE |
| ------------------------------------------------------------ |
| SELECT {schema_madlib}.get_var_importance( |
| 'model_table', -- Model table name (output of tree_train/forest_train) |
| 'output_table', -- Table name to store the predictions |
| ); |
| |
| ------------------------------------------------------------ |
| OUTPUT |
| ------------------------------------------------------------ |
| The output table ('output_table' above) has three columns. |
| 'feature' : The name of the feature |
| 'impurity_var_importance' : Impurity importance score for |
| the variable. This column will not be available in random |
| forest models unless the importance parameter is set to True |
| during training. |
| 'oob_var_importance' : Out-of-bag variable importance score |
| for the variable. This column will not be available for |
| decision tree models. |
| """ |
| elif message.lower().strip() in ['example', 'examples']: |
| help_string = """ |
| ------------------------------------------------------------ |
| EXAMPLE |
| ------------------------------------------------------------ |
| -- Assuming the example of tree_train() or forest_train() |
| has been run |
| SELECT {schema_madlib}.get_var_importance('train_output','imp_output'); |
| SELECT * FROM imp_output; |
| """ |
| else: |
| help_string = "No such option. Use {schema_madlib}.get_var_importance('usage')" |
| return help_string.format(schema_madlib=schema_madlib) |
| # ------------------------------------------------------------ |