| # coding=utf-8 |
| """ |
| @file decision_tree.py_in |
| |
| @brief Decision Tree: Driver functions |
| |
| @namespace decision_tree |
| """ |
| |
| from __future__ import division |
| import plpy |
| from math import sqrt |
| from operator import itemgetter |
| from itertools import groupby |
| from collections import Iterable |
| |
| from internal.db_utils import quote_literal |
| |
| from utilities.control import MinWarning |
| from utilities.control import OptimizerControl |
| from utilities.control import HashaggControl |
| from utilities.utilities import _assert |
| from utilities.utilities import _array_to_string |
| from utilities.utilities import _check_groups |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import unique_string |
| from utilities.utilities import add_postfix |
| from utilities.utilities import is_psql_numeric_type, is_psql_boolean_type |
| from utilities.utilities import py_list_to_sql_string |
| from utilities.utilities import split_quoted_delimited_str |
| |
| from utilities.validate_args import _get_table_schema_names |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import explicit_bool_to_text |
| from utilities.validate_args import get_cols |
| from utilities.validate_args import get_cols_and_types |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import is_var_valid |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import unquote_ident |
| |
| from validation.cross_validation import cross_validation_grouping_w_params |
| # ------------------------------------------------------------ |
| |
| |
| def _tree_validate_args( |
| split_criterion, training_table_name, output_table_name, |
| id_col_name, list_of_features, dependent_variable, |
| list_of_features_to_exclude, grouping_cols, weights, max_depth, |
| min_split, min_bucket, n_bins, cp, n_folds, **kwargs): |
| """ Validate the arguments |
| """ |
| _assert(training_table_name and |
| training_table_name.strip().lower() not in ('null', ''), |
| "Decision tree error: Invalid data table.") |
| _assert(table_exists(training_table_name), |
| "Decision tree error: Data table is missing.") |
| |
| _assert(not table_exists(output_table_name, only_first_schema=True), |
| "Decision tree error: Output table already exists.") |
| _assert(not table_exists(add_postfix(output_table_name, "_summary"), only_first_schema=True), |
| "Decision tree error: Output summary table already exists.") |
| |
| _assert(columns_exist_in_table(training_table_name, [id_col_name]), |
| "Decision tree error: ID column does not exist.") |
| |
    _assert(dependent_variable is not None and dependent_variable.strip(),
            "Decision tree error: Dependent variable is empty.")
| |
| _assert(is_var_valid(training_table_name, dependent_variable), |
| "Decision tree error: Invalid dependent variable ({0}).". |
| format(dependent_variable)) |
| |
| _assert(list_of_features and list_of_features.strip(), |
| "Decision tree error: Features to include is empty.") |
| |
| if list_of_features.strip() != '*': |
| _assert(is_var_valid(training_table_name, list_of_features), |
| "Decision tree error: Invalid feature list ({0})". |
| format(list_of_features)) |
| |
| if grouping_cols: |
| _assert(is_var_valid(training_table_name, grouping_cols), |
| "Decision tree error: Invalid grouping column argument.") |
| |
| if weights is not None and weights.strip() != '': |
| _assert(is_var_valid(training_table_name, weights), |
| "Decision tree error: Invalid weights argument.") |

    _assert(0 <= max_depth < 100,
            "Decision tree error: maximum tree depth must be "
            "non-negative and less than 100.")
| |
| _assert(not table_is_empty(training_table_name, |
| _get_filter_str(dependent_variable, grouping_cols)), |
| "Decision tree error: Data table ({0}) is empty " |
| "(after filtering invalid tuples)". |
| format(training_table_name)) |
| |
| _assert(cp >= 0, "Decision tree error: cp must be non-negative.") |
| _assert(min_split > 0, "Decision tree error: min_split must be positive.") |
| _assert(min_bucket > 0, "Decision tree error: min_bucket must be positive.") |
| _assert(n_bins > 1, "Decision tree error: number of bins must be at least 2.") |
| _assert(n_folds >= 0, "Decision tree error: number of cross-validation " |
| "folds must be non-negative.") |
| # ------------------------------------------------------------ |
| |
| |
def _validate_split_criterion(split_criterion, is_classification):
    criterion = split_criterion.lower().strip()
    _assert(criterion in ('mse', 'gini', 'cross-entropy', 'entropy',
                          'misclass', 'misclassification'),
            "Decision tree error: Invalid split_criterion.")
    if is_classification:
        if criterion == "mse":
            plpy.error("Decision tree error: MSE is not a valid "
                       "split criterion for classification.")
    else:
        if criterion != "mse":
            plpy.warning("Decision tree: Using MSE as split criterion as it "
                         "is the only one supported for regression trees.")
            split_criterion = "mse"
    return split_criterion
| # ------------------------------------------------------------------------------ |
| |
| |
| def _get_features_to_use(schema_madlib, training_table_name, |
| list_of_features, list_of_features_to_exclude, |
| id_col_name, weights, dependent_variable, |
| grouping_cols=None, **kwargs): |
| """ Expand '*' syntax and exclude some features |
| |
| Ignore 'list_of_features_to_exclude' if 'list_of_features' is not '*' |
| |
| We also exclude from the features all grouping_cols, id_col_name, weights and dependent_variable |
| """ |
| # for some of the sets below we include the quoted name and the unquoted name |
| # to allow user to provide either form. Both forms are added in the exclusion |
| # list. |
| if grouping_cols: |
| group_set = set(split_quoted_delimited_str(grouping_cols)) |
| group_set |= set(unquote_ident(i) |
| for i in split_quoted_delimited_str(grouping_cols)) |
| else: |
| group_set = set() |
| other_col_set = set([id_col_name, weights, dependent_variable]) |
| other_col_set |= set(unquote_ident(i) |
| for i in [id_col_name, weights, dependent_variable]) |
| |
| if list_of_features.strip() == '*': |
| all_col_set = set(get_cols(training_table_name)) |
| exclude_set = set(split_quoted_delimited_str(list_of_features_to_exclude)) |
| feature_set = all_col_set - exclude_set |
| filtered_feature_list = list(feature_set - group_set - other_col_set) |
| else: |
| feature_list = split_quoted_delimited_str(list_of_features) |
| feature_exclude = split_quoted_delimited_str(list_of_features_to_exclude) |
| return_set = set(feature_list) - set(feature_exclude) - group_set - other_col_set |
| # instead of returning list(return_set) we create a list that has |
| # elements in same order as original feature_list |
| filtered_feature_list = [feat for feat in feature_list if feat in return_set] |
| |
| # check if any of the features is an array and expand the array |
| final_feature_list = [] |
| for feat in filtered_feature_list: |
| feat_type = get_expr_type(feat, training_table_name) |
| if '[]' in feat_type: |
| # expand array by indexing into it |
| feat_dims = plpy.execute(""" |
| SELECT array_lower({f}, 1) as l, |
| array_upper({f}, 1) as u |
| FROM {tbl} |
| LIMIT 1 |
| """.format(f=feat, tbl=training_table_name))[0] |
| final_feature_list += ["({f})[{i}]".format(f=feat, i=i) |
| for i in range(feat_dims['l'], feat_dims['u'] + 1)] |
| else: |
| final_feature_list.append(feat) |
| return final_feature_list |
| # ------------------------------------------------------------ |
| |
| |
| def _classify_features(feature_to_type, features): |
| """ Returns |
| 1) an array of categorical features |
| 2) an array of ordered categorical features |
| 3) an array of boolean categorical variables |
| 4) an array of continuous features |
| """ |
| # any column belonging to the following types are categorical |
| int_types = ['integer', 'smallint', 'bigint'] |
| text_types = ['text', 'varchar', 'character varying', 'char', 'character'] |
| boolean_types = ['boolean'] |
| cat_types = int_types + text_types + boolean_types |
| ordered_cat_types = int_types |
| |
| cat_features = [c for c in features if feature_to_type[c] in cat_types] |
| ordered_cat_features = [c for c in features |
| if feature_to_type[c] in ordered_cat_types] |
| # In order to be able to form an array, all categorical variables |
| # will be cast into TEXT type, but GPDB cannot cast a boolean |
| # directly into a text. Thus, boolean categorical variables |
| # need special treatment: cast them into integers before casting |
| # into text. |
| boolean_cat_features = [c for c in features |
| if feature_to_type[c] in boolean_types] |
| |
| # Integer types are not considered continuous |
| con_features = [c for c in features |
| if is_psql_numeric_type(feature_to_type[c], exclude=int_types)] |
| |
| return cat_features, ordered_cat_features, boolean_cat_features, con_features |
| # ------------------------------------------------------------ |
| |
| |
| def _extract_pruning_params(pruning_params_str): |
| """ |
| Args: |
| @param pruning_param: str, Parameters used for pruning the tree |
| cp = Cost-complexity for pruning |
| n_folds = Number of folds for cross-validation |
| Returns: |
| dict. A dictionary containing the pruning parameters |
| """ |
| default_dict = dict(cp=0, n_folds=0) |
| params_types = dict(cp=float, n_folds=int) |
| pruning_params = extract_keyvalue_params(pruning_params_str, |
| params_types, |
| default_dict) |
| if pruning_params['n_folds'] < 0: |
| plpy.error("Decision Tree error: Number of cross-validation folds ({0}) " |
| "must be non-negative".format(pruning_params['n_folds'])) |
| return pruning_params |
| # ------------------------------------------------------------ |
| |
| |
| def _get_tree_states(schema_madlib, is_classification, split_criterion, |
| training_table_name, output_table_name, id_col_name, |
| dependent_variable, dep_is_bool, |
| grouping_cols, cat_features, ordered_cat_features, |
| con_features, n_bins, boolean_cats, |
| min_split, min_bucket, weights, |
| max_depth, grp_key_to_cp, compute_cp_list=False, |
| max_n_surr=0, null_proxy=None, **kwargs): |
| """ |
| Args: |
| grp_key_to_cp : Dictionary, mapping from group key to the cp value |
| for each group. For the no grouping case, the |
| key is '' |
| """ |
| filter_dep = _get_filter_str(dependent_variable, grouping_cols) |
| if is_classification: |
| # For classifications, we also need to map dependent_variable to integers |
| n_rows, dep_list = _get_n_and_deplist(training_table_name, |
| dependent_variable, |
| filter_dep) |
| if dep_is_bool: |
| # false = 0, true = 1 |
| # This order is maintained in dep_list since |
| # _get_n_and_deplist returns a sorted list |
| dep_var_str = ("(CASE WHEN {0} THEN 1 ELSE 0 END)". |
| format(dependent_variable)) |
| |
| else: |
| dep_var_str = ("(CASE " + |
| "\n\t\t".join(["WHEN ({0})::text = $${1}$$ THEN {2}". |
| format(dependent_variable, str(c), i) |
| for i, c in enumerate(dep_list)]) + |
| "\nEND)") |
| else: |
| n_rows = long(plpy.execute("SELECT count(*)::bigint " |
| "FROM {src} " |
| "WHERE {filter}". |
| format(src=training_table_name, |
| filter=filter_dep))[0]['count']) |
| dep_var_str = dependent_variable |
| dep_list = [] |
| |
| dep_n_levels = len(dep_list) if dep_list else 1 |
| cat_features_info_table = unique_string() |
| if not grouping_cols: # non-grouping case |
        # 1) Find the splitting bins, one dict containing two arrays:
        #    categorical bins and continuous bins
| bins = _get_bins(schema_madlib, training_table_name, cat_features, |
| ordered_cat_features, con_features, n_bins, |
| dep_var_str, boolean_cats, n_rows, is_classification, |
| dep_n_levels, filter_dep, null_proxy) |
| # some features may be dropped if they have only one value |
| cat_features = bins['cat_features'] |
| if not cat_features and not con_features: |
| plpy.error("Decision tree: None of the input features are valid") |
| _create_cat_features_info_table(cat_features_info_table, bins) |
| |
        # 2) Run tree train till the training is finished
        #    finished: 0 = running, 1 = finished training, 2 = terminated prematurely
| tree = _tree_train_using_bins(**locals()) |
| tree['grp_key'] = '' |
| tree['cp'] = grp_key_to_cp[tree['grp_key']] |
| tree_states = [tree] |
| else: |
| grouping_array_str = get_grouping_array_str(training_table_name, grouping_cols) |
| with OptimizerControl(False): |
| # we disable optimizer (ORCA) for platforms that use it |
| # since ORCA doesn't provide an easy way to disable hashagg |
| with HashaggControl(False): |
| # we disable hashagg since large number of groups could |
| # result in excessive memory usage. |
| plpy.notice("Analyzing data to compute split boundaries for variables") |
                # 1) Find the splitting bins for each group
                bins = _get_bins_grps(schema_madlib, training_table_name,
| cat_features, ordered_cat_features, |
| con_features, n_bins, |
| dep_var_str, |
| boolean_cats, grouping_cols, |
| grouping_array_str, n_rows, |
| is_classification, dep_n_levels, |
| filter_dep, null_proxy) |
| cat_features = bins['cat_features'] |
| if not cat_features and not con_features: |
| plpy.error("Decision tree: None of the input features " |
| "are valid for some groups") |
| _create_cat_features_info_table(cat_features_info_table, bins) |
| |
        # 2) Load each group's tree state into memory and set it to the initial tree
| tree_states = _tree_train_grps_using_bins(**locals()) |
| for tree in tree_states: |
| grp_key = tree['grp_key'] |
| if len(grp_key_to_cp.values()) == 1: |
| # for train w/out CV, the cp value remains the same for |
| # all groups. This is passed as a single-element list. |
| tree['cp'] = grp_key_to_cp.values()[0] |
| else: |
| tree['cp'] = grp_key_to_cp[grp_key] |
| |
    # 3) Prune the tree using the provided 'cp' value and produce a list of
    #    cp values if cross-validation is required (cp_list = [] if not)
| for tree in tree_states: |
| if 'cp' in tree: |
| pruned_tree = _prune_and_cplist(schema_madlib, tree, |
| tree['cp'], |
| compute_cp_list=compute_cp_list) |
| tree['tree_state'] = pruned_tree['tree_state'] |
| if 'pruned_depth' in pruned_tree: |
| tree['pruned_depth'] = pruned_tree['pruned_depth'] |
| else: |
| tree['pruned_depth'] = pruned_tree['tree_depth'] |
| if 'cp_list' in pruned_tree: |
| tree['cp_list'] = pruned_tree['cp_list'] |
| |
| importance_vectors = _compute_var_importance( |
| schema_madlib, tree, |
| len(cat_features), len(con_features)) |
| tree.update(**importance_vectors) |
| return tree_states, bins, dep_list, n_rows, cat_features_info_table |
| # ------------------------------------------------------------------------- |
| |
| |
| def get_grouping_array_str(table_name, grouping_cols, qualifier=None): |
| """ |
| Args: |
| @param grouping_cols: list, List of columns used as grouping columns |
| """ |
    def _col_to_text(col, col_type):
        qualifier_str = qualifier + "." if qualifier else ''
        if is_psql_boolean_type(col_type):
            # the qualifier applies to boolean columns as well
            return ("(case when {0}{1} then 'true' else 'false' end)::text".
                    format(qualifier_str, col))
        else:
            return '({0}{1})::text'.format(qualifier_str, col)
| |
| grouping_cols_list = [col.strip() for col in grouping_cols.split(',')] |
| grouping_cols_and_types = [(c, get_expr_type(c, table_name)) |
| for c in grouping_cols_list] |
| grouping_array_str = "array_to_string(array[{0}]::text[], ',')".format( |
| ','.join(_col_to_text(col, col_type) for col, col_type in grouping_cols_and_types)) |
| return grouping_array_str |
| # ------------------------------------------------------------------------------ |
| |
| |
| # XXX This function is used only in decision tree cross-validation. |
| # So it has some specific code for cv. |
| def _build_tree(schema_madlib, is_classification, split_criterion, |
| training_table_name, output_table_name, id_col_name, |
| dependent_variable, dep_is_bool, list_of_features, |
| cat_features, ordered_cat_features, |
| boolean_cats, con_features, grouping_cols, |
| weights, max_depth, min_split, min_bucket, n_bins, |
| cp_table, max_n_surr=0, null_proxy=None, |
| msg_level="warning", n_folds=0, **kwargs): |
| |
| compute_cp_list = False |
| if grouping_cols: |
| grouping_array_str = get_grouping_array_str(cp_table, grouping_cols) |
| else: |
| grouping_array_str = "''::TEXT" |
| sql = """SELECT |
| {0} AS grp_key, |
| explore_value as cp_val |
| FROM {1} |
| """.format(grouping_array_str, cp_table) |
| grp_cp_values = plpy.execute(sql) |
| grp_key_to_cp = dict((row['grp_key'], row['cp_val']) for row in grp_cp_values) |
| |
| with MinWarning(msg_level): |
| plpy.notice("Building tree for cross validation") |
| tree_states, bins, dep_list, n_rows, cat_features_info_table = \ |
| _get_tree_states(**locals()) |
| all_cols_types = dict([(f, get_expr_type(f, training_table_name)) |
| for f in cat_features + con_features]) |
| |
| n_all_rows = plpy.execute("select count(*) from " + training_table_name |
| )[0]['count'] |
| cp = grp_key_to_cp.values()[0] |
| |
| # _create_output_tables(...) is a general function |
| # we need to let it know that right now it is called from CV. |
    use_existing_tables = table_exists(output_table_name)  # insert into the output tables if they already exist
| running_cv = True # flag to indicate that cv fold ID needs to be included |
| _create_output_tables(list_of_features_to_exclude='', |
| **locals()) |
| # ------------------------------------------------------------------------------ |
| |
| |
| def tree_train(schema_madlib, training_table_name, output_table_name, |
| id_col_name, dependent_variable, list_of_features, |
| list_of_features_to_exclude, split_criterion, |
| grouping_cols, weights, max_depth, |
| min_split, min_bucket, n_bins, pruning_params, |
| null_handling_params, verbose_mode, **kwargs): |
| """ Decision tree main training function |
| |
| Args: |
| @param schema_madlib: str, MADlib schema name |
| @param training_table_name: str, source table name |
| @param output_table_name: str, model table name |
| @param id_col_name: str, id column name to uniquely identify each row |
| @param dependent_variable: str, dependent variable column name |
| @param list_of_features: str, Comma-separated list of feature column names, |
| can also be '*' implying all columns |
| except dependent_variable |
| @param list_of_features_to_exclude: str, features to exclude if '*' is used |
| @param split_criterion: str, Impurity function to use for splitting: |
| Classification: {'gini', 'entropy', 'misclass'} |
| Regression: 'mse' |
| @param grouping_cols: str, List of grouping columns to group the data |
| @param weights: str, Column name for weight for each tuple |
| @param max_depth: int, Maximum depth of the tree |
| @param min_split: int, Minimum tuples in a node before splitting it |
| @param min_bucket: int, Minimum tuples in each child before splitting a node |
        @param n_bins: int, Number of bins for quantizing continuous variables
        @param pruning_params: str, Parameters used for pruning the tree
                               cp = Cost-complexity for pruning
                               n_folds = Number of folds for cross-validation
        @param null_handling_params: str, Parameters for handling NULL values
                                     max_surrogates = Number of surrogate
                                                      splits per internal node
                                     null_as_category = Whether to treat NULL
                                                        as a valid category
        @param verbose_mode: bool, Verbosity of output messages
| """ |
| msg_level = "notice" if verbose_mode else "warning" |
| |
| # Set default values for all arguments |
| split_criterion = 'gini' if not split_criterion else split_criterion |
| max_depth = 7 if max_depth is None else max_depth |
| if min_split is None and min_bucket is None: |
| min_split = 20 |
| min_bucket = 6 |
| else: |
| min_bucket = min_split // 3 if min_bucket is None else min_bucket |
| min_split = min_bucket * 3 if min_split is None else min_split |
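    # e.g. min_split=30 with min_bucket omitted gives min_bucket=10, while
    # min_bucket=5 with min_split omitted gives min_split=15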
| n_bins = 20 if n_bins is None else n_bins |
| |
| # defaults for cp and n_folds set within _extract_pruning_params |
| pruning_param_dict = _extract_pruning_params(pruning_params) |
| cp = pruning_param_dict['cp'] |
| n_folds = pruning_param_dict['n_folds'] |
| |
| # null handing parameters: max_n_surr and null_as_category |
| null_handling_dict = extract_keyvalue_params( |
| null_handling_params, |
| dict(max_surrogates=int, null_as_category=bool), |
| dict(max_surrogates=0, null_as_category=False)) |
| max_n_surr = null_handling_dict['max_surrogates'] |
| null_as_category = null_handling_dict['null_as_category'] |
| null_proxy = "__NULL__" if null_as_category else None |
| if null_as_category: |
| # can't have two ways of handling tuples with NULL values |
| max_n_surr = 0 |
| _assert(max_n_surr >= 0, |
| "Maximum number of surrogates ({0}) should be non-negative". |
| format(max_n_surr)) |
| |
| with MinWarning(msg_level): |
| # 1) |
| if not grouping_cols or not grouping_cols.strip(): |
| grouping_cols = '' |
| _tree_validate_args(**locals()) |
        # the weights column has to be validated as present,
        # hence default value allocation happens after the validation
| weights = '1' if not weights or not weights.strip() else weights.strip() |
| |
| # expand "*" syntax and exclude some features |
| features = _get_features_to_use(**locals()) |
| _assert(bool(features), |
| "Decision tree error: No feature is selected for the model.") |
| |
| # 2) |
| all_cols_types = dict([(f, get_expr_type(f, training_table_name)) |
| for f in features]) |
| cat_features, ordered_cat_features, boolean_cats, con_features = \ |
| _classify_features(all_cols_types, features) |
| |
| # assert that the continuous and categorical features together |
| # cover all features |
| invalid_features = set(features) - (set(cat_features) | set(con_features)) |
| _assert(not invalid_features, |
| "DT error: Some of the features are invalid ({0})". |
| format(invalid_features)) |
| |
| # get all rows |
| n_all_rows = plpy.execute("SELECT count(*) FROM {source_table}". |
| format(source_table=training_table_name) |
| )[0]['count'] |
| |
| is_classification, dep_is_bool = _is_dep_categorical( |
| training_table_name, dependent_variable) |
| split_criterion = _validate_split_criterion(split_criterion, is_classification) |
| |
        # 3) Build the tree with the provided cp value
| compute_cp_list = (n_folds > 1) |
| grp_key_to_cp = {'': cp} |
| # main training function to get trained decision trees |
| plpy.notice("Getting initial tree") |
| tree_states, bins, dep_list, n_rows, cat_features_info_table = \ |
| _get_tree_states(**locals()) |
| |
        # 4) Perform cross-validation to compute the lowest cp
| dep_n_levels = len(dep_list) if dep_list else 1 |
| if n_folds > 1: |
| plpy.notice("Running cross validation") |
| _xvalidate(**locals()) |
| |
| plpy.notice("Creating output tables") |
| _create_output_tables(**locals()) |
| return None |
| # ------------------------------------------------------------ |
| |
| |
| def _create_output_tables(schema_madlib, training_table_name, output_table_name, |
| tree_states, bins, split_criterion, |
| id_col_name, dependent_variable, list_of_features, |
| list_of_features_to_exclude, |
| is_classification, n_all_rows, n_rows, dep_list, cp, |
| all_cols_types, cat_features_info_table, |
| grouping_cols=None, |
| use_existing_tables=False, running_cv=False, |
| n_folds=0, null_proxy=None, **kwargs): |
| if not grouping_cols: |
| _create_result_table(schema_madlib, tree_states[0], |
| bins['cat_origin'], bins['cat_n'], bins['cat_features'], |
| bins['con_features'], output_table_name, |
| use_existing_tables, running_cv, n_folds) |
| else: |
| _create_grp_result_table( |
| schema_madlib, tree_states, bins, bins['cat_features'], |
| bins['con_features'], output_table_name, cat_features_info_table, |
| grouping_cols, training_table_name, use_existing_tables, |
| running_cv, n_folds) |
| |
| failed_groups = sum(row['finished'] != 1 for row in tree_states) |
| _create_summary_table( |
| schema_madlib, split_criterion, training_table_name, |
| output_table_name, id_col_name, bins['cat_features'], bins['con_features'], |
| dependent_variable, list_of_features, list_of_features_to_exclude, |
| failed_groups, is_classification, n_all_rows, |
| n_rows, dep_list, all_cols_types, cp, grouping_cols, 1, |
| use_existing_tables, n_folds, null_proxy) |
| # ------------------------------------------------------------------------- |
| |
| |
| def _get_n_and_deplist(training_table_name, dependent_variable, filter_null): |
| """ |
| @brief Query the database for the total number of rows and |
| levels of dependent variable if the dependent variable is |
| categorical. |
| |
| Note: The deplist is sorted in the array_agg which is necessary to ensure |
| false = 0 and true = 1 for boolean dependent variable. |
| """ |
| sql = """ |
| SELECT |
| sum(n_rows) as n_rows, |
| array_agg(dep ORDER BY dep) as dep |
| FROM ( |
| SELECT |
| count(*) as n_rows, |
| {dep_var} as dep |
| FROM |
| {source} |
| WHERE {filter_null} |
| GROUP BY ({dep_var}) |
| ) s |
| """.format(dep_var=dependent_variable, |
| source=training_table_name, |
| filter_null=filter_null) |
| r = plpy.execute(sql)[0] |
| r['n_rows'] = 0 if r['n_rows'] is None else r['n_rows'] |
| return long(r['n_rows']), r['dep'] |
| # ------------------------------------------------------------ |
| |
| |
| def _is_dep_categorical(training_table_name, dependent_variable): |
| """ |
| @brief Sample the dependent variable to check whether it is |
| a categorical variable. |
| """ |
| sample_dep = get_expr_type(dependent_variable, training_table_name) |
| is_dep_numeric = is_psql_numeric_type(sample_dep, |
| exclude=['smallint', |
| 'integer', |
| 'bigint']) |
| is_dep_bool = is_psql_boolean_type(sample_dep) |
| return (not is_dep_numeric, is_dep_bool) |
| # ------------------------------------------------------------ |
| |
| |
| def _get_bins(schema_madlib, training_table_name, |
| cat_features, ordered_cat_features, |
| con_features, n_bins, dependent_variable, boolean_cats, n_rows, |
| is_classification, dep_n_levels, filter_null_dep, |
| null_proxy=None): |
| """ Compute the bins of all features |
| |
| @param training_table_name Data source table |
| @param cat_features A list of strings, categorical column names |
| @param con_features A list of strings, continuous column names |
| @param n_bins Number of splits equals n_bins - 1 |
| @param dependent_variable Will be needed when sorting the levels of |
| categorical variables |
| @param boolean_cats The categorical variables that are of boolean type |
| |
| return one dictionary containing two arrays: categorical and continuous |
| """ |
| if len(con_features) > 0: |
| if n_bins > n_rows: |
| plpy.error("Decision tree error: Number of bins is larger than " |
| "the number of data records.") |
| sample_size = n_bins * n_bins # We use Spark's value here |
| # FIXME Hard coded number |
| if sample_size < 10000: |
| sample_size = 10000 |
| |
| # Compute the percentage of the sample. |
| # We sample a few more values to make sure we can get enough |
| # samples, otherwise the number of samples might be smaller |
| # than sample_size. |
| # Use design doc Eq. (2.1) |
| actual_sample_size = sample_size + 14 + sqrt(196 + 28 * sample_size) |
| if actual_sample_size > n_rows: |
| actual_sample_size = n_rows |
| percentage = actual_sample_size / n_rows |
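        # e.g. with n_bins = 20, sample_size = max(400, 10000) = 10000 and
        # actual_sample_size = 10000 + 14 + sqrt(196 + 280000) ~ 10543,
        # capped at n_rows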
| |
| # For continuous variables, use one function to compute |
| # the splits for all of them. Similar to the existing |
| # _compute_splits function in CoxPH module, but deal with |
| # multiple columns together. |
| con_features_str = py_list_to_sql_string(con_features, "double precision") |
| |
| sample_table_name = unique_string() |
| plpy.execute(""" |
| CREATE TEMP TABLE {sample_table_name} AS |
| SELECT * |
| FROM ( |
| SELECT *, random() AS rand |
| FROM {training_table_name} |
| WHERE {filter_null_dep} |
| AND not {schema_madlib}.array_contains_null({con_features_str}) |
| ) subq |
| WHERE rand <= {percentage} |
| """.format(**locals())) |
| |
| # The splits for continuous variables |
| con_split_str = ("""{schema_madlib}._dst_compute_con_splits( |
| {con_features_str}, |
| {actual_sample_size}::integer, |
| {n_bins}::smallint)""". |
| format(**locals())) |
| con_splits = plpy.execute(""" |
| SELECT {con_split_str} as con_splits |
| FROM {sample_table_name} |
| """.format(**locals()))[0] |
| |
| plpy.execute("DROP TABLE {sample_table_name}".format(**locals())) |
| else: |
| con_splits = {'con_splits': ''} # no continuous features present |
| |
| # For categorical variables, scan the whole table to extract all the |
| # levels of the categorical variables, and at the same time |
| # sort the levels according to the entropy of the dependent |
| # variable. |
| # So this aggregate returns a composite type with two columns: |
| # col 1 is the array of ordered levels; col 2 is the number of |
| # levels in col 1. |
| |
    # TODO When n_bins is larger than 2^k - 1, where k is the number
    # of levels of a given categorical feature, we can actually compute
    # all combinations of levels and obtain a complete set of splits
    # instead of using sorting to get an approximate set of splits.
    #
    # We use integers to represent the levels of categorical variables.
    # So before everything, we need to create a mapping from categorical
    # variable levels to integers, and keep this mapping in memory.
| if len(cat_features) > 0: |
| if is_classification: |
| # For classification the dependent variable is encoded as an integer |
| order_fun = ("{madlib}._dst_compute_entropy({dep}, {n})". |
| format(madlib=schema_madlib, |
| dep=dependent_variable, |
| n=dep_n_levels)) |
| else: |
| # For regressions |
| order_fun = "AVG({0})".format(dependent_variable) |
| |
| # Note that 'sql_cat_levels' goes through two levels of string formatting |
| # Try to obtain all the levels in one scan of the table. |
| # () are needed when casting the categorical variables because |
| # they can be expressions. |
| |
| filter_str = filter_null_dep + " AND {col} IS NOT NULL" |
| |
| if null_proxy is None: |
| union_null_proxy = "" |
| else: |
| union_null_proxy = "UNION " + """ |
| SELECT '{null_proxy}'::text as levels, |
| 'Infinity'::double precision as dep_avg |
| FROM {training_table_name} |
| WHERE {{col}} IS NULL |
| GROUP BY {{col}} |
| """.format(**locals()) |
| |
| sql_cat_levels = """ |
| SELECT |
| '{{col_name}}'::text AS colname, |
| levels |
| FROM ( |
| SELECT |
| '{{col_name}}'::text AS colname, |
| array_agg(levels ORDER BY dep_avg) AS levels |
| FROM ( |
| SELECT |
| ({{col}})::text AS levels, |
| {{order_fun}} AS dep_avg |
| FROM {training_table_name} |
| WHERE {filter_str} |
| GROUP BY {{col}} |
| {union_null_proxy} |
| ) s |
| ) s1 |
| """.format(training_table_name=training_table_name, |
| filter_str=filter_str, |
| union_null_proxy=union_null_proxy) |
| |
| all_col_expressions = dict(zip(cat_features, explicit_bool_to_text( |
| training_table_name, cat_features, schema_madlib))) |
| |
| sql_all_cats = ' UNION '.join( |
| sql_cat_levels.format( |
| col=expr, |
| col_name=col_name, |
| order_fun=expr if col_name in ordered_cat_features else order_fun) |
| for col_name, expr in all_col_expressions.items()) |
| all_levels = plpy.execute(sql_all_cats) |
| col_to_row = dict((row['colname'], i) for i, row in enumerate(all_levels)) |
| |
| return dict( |
| con=con_splits['con_splits'], |
| con_features=con_features, |
| cat_origin=[level for col in cat_features |
| for level in all_levels[col_to_row[col]]['levels']], |
| cat_n=[len(all_levels[col_to_row[col]]['levels']) |
| for col in cat_features], |
| cat_features=cat_features) |
| else: |
| # categorical part is empty |
| return dict( |
| con=con_splits['con_splits'], |
| con_features=con_features, |
| cat_origin=[], |
| cat_n=[], |
| cat_features=[]) |
| # ------------------------------------------------------------ |
| |
| |
| def _create_result_table(schema_madlib, tree_state, |
| cat_origin, cat_n, cat_features, |
| con_features, output_table_name, |
| use_existing_tables=False, running_cv=False, |
| k=0): |
| """ Create the output table and the summary table |
| |
| In the result table, we need the tree_state and also the categorical |
| sorted levels, which will be used in the printing and the prediction |
| functions. |
| """ |
| fold = ", " + str(k) + " as k" if running_cv else "" |
| if use_existing_tables: |
| # plpy.execute("truncate " + output_table_name) |
| header = "insert into " + output_table_name + " " |
| else: |
| header = "create table " + output_table_name + " as " |
| depth = (tree_state['pruned_depth'] if 'pruned_depth' in tree_state |
| else tree_state['tree_depth']) |
| |
| if len(cat_features) > 0: |
| sql = header + """ |
| SELECT |
| {cp} AS pruning_cp, |
| $1 AS tree, |
| $2 AS cat_levels_in_text, |
| $3 AS cat_n_levels, |
| $4 AS impurity_var_importance, |
| {depth} AS tree_depth |
| {fold} |
| """.format(depth=depth, |
| cp=tree_state['cp'], |
| fold=fold) |
| sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib), |
| 'text[]', |
| 'integer[]', |
| 'double precision[]']) |
| plpy.execute(sql_plan, [tree_state['tree_state'], |
| cat_origin, |
| cat_n, |
| tree_state['impurity_var_importance']]) |
| else: |
| sql = header + """ |
| SELECT |
| {cp} AS pruning_cp, |
| $1 AS tree, |
| NULL::text[] AS cat_levels_in_text, |
| NULL::integer[] AS cat_n_levels, |
| $2 AS impurity_var_importance, |
| {depth} AS tree_depth |
| {fold} |
| """.format(depth=depth, |
| cp=tree_state['cp'], |
| fold=fold) |
| sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib), |
| 'double precision[]']) |
| plpy.execute(sql_plan, [tree_state['tree_state'], |
| tree_state['impurity_var_importance']]) |
| |
| # ------------------------------------------------------------ |
| |
| |
| def _get_bins_grps( |
| schema_madlib, training_table_name, cat_features, ordered_cat_features, |
| con_features, n_bins, dependent_variable, boolean_cats, |
| grouping_cols, grouping_array_str, n_rows, is_classification, |
| dep_n_levels, filter_null_dep, null_proxy=None): |
| """ Compute the bins for all features in each group |
| |
| @brief Similar to _get_bins except that this is for multiple groups. |
| So please refer to the comments inside _get_bins for more related |
| information. |
| |
    Returns a dict with the per-group bins, in the form consumed by
    _one_step_for_grps.
    """
| if len(con_features) > 0: |
| if n_bins > n_rows: |
| plpy.error("Decision tree error: Number of bins is larger than " |
| "the number of data records.") |
| sample_size = n_bins * n_bins # We use Spark's value here |
| # FIXME Hard coded number |
| if sample_size < 10000: |
| sample_size = 10000 |
| |
| # Compute the percentage of the sample. |
| # We sample a few more values to make sure we can get enough |
| # samples, otherwise the number of samples might be smaller |
| # than sample_size. |
| # Use design doc Eq. (2.1) |
| # FIXME Should also use this for the CoxPH module |
| sample_size = sample_size + 14 + sqrt(196 + 28 * sample_size) |
| |
| grp_key_str = unique_string() |
| n_per_seg_str = unique_string() |
| grp_size_str = unique_string() |
| random_str = unique_string() |
| |
| con_features_str = py_list_to_sql_string(con_features, "double precision") |
| sample_table_name = unique_string() |
| sql = """ |
| CREATE TEMP TABLE {sample_table_name} AS |
| SELECT * |
| FROM |
| ( |
| SELECT |
| *, |
| random() AS rand, |
| CASE WHEN {grp_size_str} < {sample_size} |
| THEN {grp_size_str} |
| ELSE {sample_size} |
| END AS {n_per_seg_str} |
| FROM |
| ( |
| SELECT * |
| FROM {training_table_name} |
| WHERE {filter_null_dep} |
| AND not {schema_madlib}.array_contains_null({con_features_str}) |
| ) src |
| JOIN |
| ( |
| SELECT {grouping_cols}, count(*) AS {grp_size_str} |
| FROM {training_table_name} |
| WHERE {filter_null_dep} |
| AND not {schema_madlib}.array_contains_null({con_features_str}) |
| GROUP BY {grouping_cols} |
| ) grp_info |
| USING ({grouping_cols}) |
| ) subq |
| WHERE rand * {grp_size_str} <= {sample_size} |
| """.format(**locals()) |
| plpy.execute(sql) |
| |
| # splits is a list, each of whose elements is a dictionary. |
| # The dictionary contains 2 items: |
| # 1) grp_key - It is an array of text |
| # 2) con_splits - continuous split array |
| con_split_str = """{schema_madlib}._dst_compute_con_splits( |
| {con_features_str}, |
| {n_per_seg}::integer, |
| {n_bins}::smallint)""".format(con_features_str=con_features_str, |
| schema_madlib=schema_madlib, |
| n_per_seg=n_per_seg_str, |
| n_bins=n_bins) |
| con_splits_all = plpy.execute( |
| """ SELECT |
| {con_split_str} AS con_splits, |
| {grouping_array_str} AS grp_key |
| FROM {sample_table_name} |
| GROUP BY {grouping_cols} |
| """.format(**locals()) # multiple rows |
| ) |
| |
| plpy.execute("DROP TABLE {sample_table_name}".format(**locals())) |
| |
| if cat_features: |
| if is_classification: |
| # For classifications |
| order_fun = ("{schema_madlib}._dst_compute_entropy(" |
| "{dependent_variable}, {n})". |
| format(schema_madlib=schema_madlib, |
| dependent_variable=dependent_variable, |
| n=dep_n_levels)) |
| else: |
| order_fun = "avg({0})".format(dependent_variable) |
| |
| filter_str = filter_null_dep + " AND {col} IS NOT NULL" |
| if null_proxy is None: |
| union_null_proxy = "" |
| else: |
| union_null_proxy = "UNION " + """ |
| SELECT '{null_proxy}'::text as levels, |
| 'Infinity'::double precision as dep_avg |
| FROM {training_table_name} |
| WHERE {{col}} IS NULL |
| GROUP BY {{col}} |
| """.format(**locals()) |
| |
| sql_cat_levels = """ |
| SELECT |
| colname::text, |
| levels::text[], |
| grp_key::text |
| from ( |
| SELECT |
| grp_key, |
| '{{col_name}}' as colname, |
| array_agg(levels order by dep_avg) as levels |
| from ( |
| SELECT |
| {grouping_array_str} as grp_key, |
| ({{col}})::text as levels, |
| {{order_fun}} as dep_avg |
| FROM {training_table_name} |
| WHERE {filter_str} |
| GROUP BY {{col}}, {grouping_cols} |
| ) s |
| GROUP BY grp_key |
| ) s1 |
| """.format(**locals()) |
| |
| all_col_expressions = dict(zip(cat_features, explicit_bool_to_text( |
| training_table_name, cat_features, schema_madlib))) |
| sql_all_cats = ' UNION ALL '.join( |
| sql_cat_levels.format( |
| col=expr, |
| col_name=col_name, |
| order_fun=expr if col_name in ordered_cat_features else order_fun) |
| for col_name, expr in all_col_expressions.items()) |
| |
| all_levels = list(plpy.execute(sql_all_cats)) |
| all_levels.sort(key=itemgetter('grp_key')) |
| |
| # grp_col_to_levels is a list of tuples (pairs) with |
| # first value = group value, |
| # second value = a dict mapping a categorical column to its levels in data |
| # (these levels are specific to the group and can be different |
| # for different groups) |
| # The list of tuples can be converted to a dict, but the ordering |
| # will be lost. |
| # eg. grp_to_col_to_levels = |
| # [ |
| # ('3', {'vs': [0, 1], 'cyl': [4,6,8]}), |
| # ('4', {'vs': [0, 1], 'cyl': [4,6]}), |
| # ('5', {'vs': [0, 1], 'cyl': [4,6,8]}) |
| # ] |
| grp_to_col_to_levels = [ |
| (grp_key, dict((row['colname'], row['levels']) for row in items)) |
| for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))] |
| grp_to_cat_features = dict([(g, col_to_levels.keys()) |
| for (g, col_to_levels) in grp_to_col_to_levels]) |
| # Below statements collect the grp_to_col_to_levels into multiple variables |
| # From above eg. |
| # cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]] |
| # cat_n = [2, 3, 2, 2, 2, 3] |
| # cat_origin = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8] |
| # grp_key_cat = ['3', '4', '5'] |
| cat_items_list = [rows[col] |
| for grp_key, rows in grp_to_col_to_levels |
| for col in cat_features if col in rows] |
| cat_n = [len(i) for i in cat_items_list] |
| cat_origin = [item for sublist in cat_items_list for item in sublist] |
| grp_key_cat = [item[0] for item in grp_to_col_to_levels] |
| else: |
| cat_n = [] |
| cat_origin = [] |
| grp_key_cat = [con_splits['grp_key'] for con_splits in con_splits_all] |
| grp_to_col_to_levels = [(con_splits['grp_key'], dict()) |
| for con_splits in con_splits_all] |
| grp_to_cat_features = dict([(con_splits['grp_key'], list()) |
| for con_splits in con_splits_all]) |
| |
| if con_features: |
| con = [con_splits['con_splits'] for con_splits in con_splits_all] |
| grp_key_con = [con_splits['grp_key'] for con_splits in con_splits_all] |
| else: |
| grp_key_con = [grp_key[0] for grp_key in grp_to_col_to_levels] |
| con = [''] * len(grp_key_con) |
| |
| return dict(con=con, |
| con_features=con_features, |
| grp_key_con=grp_key_con, |
| cat_origin=cat_origin, |
| cat_n=cat_n, |
| cat_features=cat_features, |
| grp_key_cat=grp_key_cat, |
| grouping_array_str=grouping_array_str, |
| grp_to_col_to_levels=grp_to_col_to_levels, |
| ) |
| # ------------------------------------------------------------ |
| |
| |
| def _create_cat_features_info_table(cat_features_info_table, bins): |
| # bins['grp_to_col_to_levels'] = |
| # [ |
| # ('3', {'vs': [0, 1], 'cyl': [4,6,8]}), |
| # ('4', {'vs': [0, 1], 'cyl': [4,6]}), |
| # ('5', {'vs': [0, 1]}) |
| # ] |
| # Convert this into a VALUES command and place in a table |
| # VALUES (('3', ARRAY[2, 3], ARRAY['0', '1', '4', '6', '8']), |
| # ('4', ARRAY[2, 2], ARRAY['0', '1', '4', '6']), |
| # ('5', ARRAY[2], ARRAY['0', '1']), |
| # ) |
| cat_features_info_values = [] |
| if 'grp_to_col_to_levels' in bins: |
| # Grouping enabled, implies the cat levels can be different for |
| # different groups |
| for i, (grp_key, col_to_levels) in enumerate(bins['grp_to_col_to_levels'], start=1): |
| grp_key_str = quote_literal(grp_key) |
| cat_names_levels = [(c, col_to_levels[c]) for c in bins['cat_features'] |
| if c in col_to_levels] |
| if cat_names_levels: |
| cat_names, cat_levels = zip(*cat_names_levels) |
| # categorical features in current group (expressed in an array) |
| cat_names_str = py_list_to_sql_string( |
| map(quote_literal, cat_names), 'text', long_format=True) |
| # number of levels in each cat feature |
| cat_n_levels_str = py_list_to_sql_string( |
| map(len, cat_levels), 'integer', long_format=True) |
| # flatten the levels across all cat features |
| cat_levels = [quote_literal(each_level) |
| for sublist in cat_levels |
| for each_level in sublist] |
| cat_levels_str = py_list_to_sql_string(cat_levels, |
| 'text', |
| long_format=True) |
| else: |
| # this is the case if no categorical features present |
| cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" |
| |
| cat_features_info_values.append( |
| "({i}::INTEGER, {grp_key_str}::TEXT, {cat_names_str}::TEXT[], " |
| "{cat_n_levels_str}::INTEGER[], {cat_levels_str}::TEXT[])". |
| format(**locals())) |
| else: |
| # no grouping |
| if bins['cat_features']: |
| cat_names_str = py_list_to_sql_string( |
| map(quote_literal, bins['cat_features']), 'text', long_format=True) |
| cat_n_levels_str = py_list_to_sql_string(bins['cat_n'], 'integer', long_format=True) |
| cat_levels_str = py_list_to_sql_string( |
| map(quote_literal, bins['cat_origin']), 'text', long_format=True) |
| else: |
| cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" |
| cat_features_info_values.append( |
| "(1::INTEGER, ''::TEXT, {0}::TEXT[], {1}::INTEGER[], {2}::TEXT[])". |
| format(cat_names_str, cat_n_levels_str, cat_levels_str)) |
| |
| sql_cat_features_info = """ |
| CREATE TEMP TABLE {0} AS |
| SELECT * |
| FROM ( |
| VALUES {1} |
| ) AS q(gid, grp_key, cat_names, cat_n_levels, cat_levels_in_text) |
| """.format(cat_features_info_table, |
| ',\n'.join(cat_features_info_values)) |
| plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info) |
    plpy.execute(sql_cat_features_info)
| # ------------------------------------------------------------------------------ |
| |
| |
| def get_feature_str(schema_madlib, source_table, |
| cat_features, con_features, |
| levels_str, n_levels_str, |
| null_proxy=None): |
| if len(cat_features) > 0: |
        # null_val is the replacement for NULL in a categorical feature. If a
| # null_proxy is set then the proxy is used to assign NULL as a valid |
| # category. If no proxy is available then NULL is replaced with a unique |
| # value. In a later step, the categorical levels are mapped to integers |
| # (1 to N). The unique value will be mapped to -1 indicating an |
| # unknown/missing value in the underlying layers. |
| null_val = unique_string() if null_proxy is None else null_proxy |
| |
| # Cast boolean column to text: requires a special cast expression for |
| # platforms where __HAS_BOOL_TO_TEXT_CAST__ is not enabled |
| patched_cat_features = explicit_bool_to_text(source_table, |
| cat_features, |
| schema_madlib) |
| cat_features_cast = [] |
| for col in patched_cat_features: |
| cat_features_cast.append( |
| "(coalesce(({0})::text, '{1}'))::text".format(col, null_val)) |
| |
| cat_features_str = ("{0}._map_catlevel_to_int(array[" + |
| ", ".join(cat_features_cast) + "], {1}, {2}, {3})" |
| ).format(schema_madlib, |
| levels_str, |
| n_levels_str, |
| null_proxy is not None) |
| else: |
| cat_features_str = "NULL" |
| |
| if len(con_features) > 0: |
| con_features_list = ["COALESCE(" + str(con) + ", 'nan'::double precision)" |
| for con in con_features] |
| con_features_str = "ARRAY[" + ", ".join(con_features_list) + "]" |
| else: |
| con_features_str = "NULL" |
| return cat_features_str, con_features_str |
| # ------------------------------------------------------------------------- |
| |
| |
| def _one_step(schema_madlib, training_table_name, cat_features, |
| con_features, boolean_cats, bins, n_bins, tree_state, weights, |
| dep_var, min_split, min_bucket, max_depth, filter_null, |
| dep_n_levels, subsample, n_random_features, |
| max_n_surr=0, null_proxy=None): |
| """ One step of tree training |
| |
    @param tree_state A big double precision array that contains
| (1) internal node: column and split value |
| (2) leaf node: the statistics described as a series of numbers |
| """ |
| # The function _map_catlevel_to_int maps a categorical variable value to its |
| # integer representation. It returns an integer array. |
    # XXX cat_features_str contains the placeholders $3 and $2, and a SQL function
| bytea8 = schema_madlib + '.bytea8' |
| cat_features_str, con_features_str = get_feature_str(schema_madlib, |
| training_table_name, |
| cat_features, |
| con_features, |
| "$3", "$2", |
| null_proxy) |
| |
| train_sql = """ |
| SELECT (result).* from ( |
| SELECT |
| {schema_madlib}._dt_apply( |
| $1, |
| {schema_madlib}._compute_leaf_stats( |
| $1, -- current tree state, madlib.bytea8 |
| {cat_features_str}, -- categorical features in an array |
| {con_features_str}, -- continuous features in an array |
| {dep_var}, |
| {weights}, -- weight value |
| $2, -- categorical sorted levels in a combined array |
| $4, -- continuous splits |
| {dep_n_levels}::smallint, -- number of dependent levels |
| {subsample}::boolean -- should we use a subsample of data |
| ), |
| $4, |
| {min_split}::smallint, |
| {min_bucket}::smallint, |
| {max_depth}::smallint, |
| {subsample}::boolean, |
| {n_random_features}::integer |
| ) as result |
| FROM {training_table_name} |
| WHERE {filter_null} |
| ) s |
| """.format(**locals()) |
| train_sql_plan = plpy.prepare(train_sql, [bytea8, 'integer[]', 'text[]', bytea8]) |
| # return a new tree state |
| updated_tree = plpy.execute(train_sql_plan, [tree_state['tree_state'], |
| bins['cat_n'], |
| bins['cat_origin'], |
| bins['con'] |
| ])[0] |
| # Compute surrogates: |
| # tree_depth outside the scope of dt_apply starts from 0 i.e. 0 depth |
| # implies a single leaf node for the tree. We compute surrogates only for |
    # internal nodes. Hence we at least need the root to be an internal node,
    # i.e., we need the tree_depth to be 1 or more.
| if max_n_surr > 0 and updated_tree['tree_depth'] > 0: |
| dup_count_expr = weights + "::integer" if subsample else '1' |
| surr_sql = """ |
| SELECT |
| {schema_madlib}._dt_surr_apply( |
| $1, |
| {schema_madlib}._compute_surr_stats( |
| $1, |
| {cat_features_str}, |
| {con_features_str}, |
| $2, |
| $4, |
| {dup_count_expr}), |
| $4) as result |
| FROM {training_table_name} |
| WHERE {filter_null} |
| """.format(**locals()) |
| surr_sql_plan = plpy.prepare(surr_sql, [bytea8, 'integer[]', 'text[]', bytea8]) |
| # return a new tree state |
| updated_tree['tree_state'] = plpy.execute(surr_sql_plan, |
| [updated_tree['tree_state'], |
| bins['cat_n'], |
| bins['cat_origin'], |
| bins['con'] |
| ])[0]["result"] |
| return updated_tree |
| # ------------------------------------------------------------ |
| |
| |
| def _one_step_for_grps( |
| schema_madlib, training_table_name, cat_features, |
| con_features, boolean_cats, bins, n_bins, tree_states, weights, |
| grouping_cols, grouping_array_str, dep_var, min_split, min_bucket, |
| max_depth, filter_null, dep_n_levels, subsample, n_random_features, |
| cat_features_info_table, max_n_surr=0, null_proxy=None): |
| """ One step of trees training with grouping support |
| """ |
| # The function _map_catlevel_to_int maps a categorical variable value to its |
| # integer representation. It returns an integer array. |
    # XXX cat_features_str references the generated cat_levels_in_text and
    # cat_n_levels columns, and a SQL function
| |
| bytea8arr = schema_madlib + '.bytea8[]' |
| # avoid name conflicts |
| grp_key = unique_string() |
| tree_state = unique_string() |
| con_splits = unique_string() |
| cat_n_levels = unique_string() |
| finished = unique_string() |
| cat_levels_in_text = unique_string() |
| |
| cat_features_str, con_features_str = get_feature_str( |
| schema_madlib, training_table_name, cat_features, con_features, |
| cat_levels_in_text, cat_n_levels, null_proxy) |
| |
| train_apply_func = """ |
| {schema_madlib}._dt_apply( |
| {tree_state}, |
| agg_result, |
| {con_splits}, |
| {min_split}::smallint, |
| {min_bucket}::smallint, |
| {max_depth}::smallint, |
| {subsample}::boolean, |
| {n_random_features}::integer) |
| """.format(**locals()) |
| |
| train_aggregate = """ |
| {schema_madlib}._compute_leaf_stats( |
| {tree_state}, |
| {cat_features_str}, |
| {con_features_str}, |
| {dep_var}, |
| {weights}, |
| {cat_n_levels}, |
| {con_splits}, |
| {dep_n_levels}::smallint, |
| {subsample}::boolean) |
| """.format(**locals()) |
| |
| sql = """ |
| SELECT |
| s1.grp_key, |
| {apply_func} as result |
| FROM ( SELECT |
| {grouping_array_str} AS grp_key, |
| {aggregate} AS agg_result |
| FROM |
| {training_table_name} as src, |
| ( SELECT |
| grp_key AS {grp_key}, |
| finished AS {finished}, |
| tree_state AS {tree_state}, |
| con_splits AS {con_splits}, |
| cat_n_levels::INTEGER[] AS {cat_n_levels}, |
| cat_levels_in_text::TEXT[] AS {cat_levels_in_text} |
| FROM |
| ( SELECT |
| unnest($1) AS grp_key, |
| unnest($2) AS finished, |
| unnest($3) AS tree_state |
| ) AS tree_state_set |
| JOIN ( |
| SELECT |
| unnest($4) AS grp_key, |
| unnest($5) AS con_splits |
| ) AS con_splits |
| USING (grp_key) |
| JOIN |
| {cat_features_info_table} |
| USING (grp_key) |
| ) AS needed_data |
| WHERE {grouping_array_str} = {grp_key} |
| AND {filter_null} |
| GROUP BY {grouping_cols} |
| ) s1 |
| JOIN ( SELECT |
| grp_key, |
| tree_state AS {tree_state}, |
| con_splits AS {con_splits} |
| FROM ( SELECT |
| unnest($1) AS grp_key, |
| unnest($3) AS tree_state |
| ) AS tree_state_set |
| JOIN |
| ( SELECT |
| unnest($4) AS grp_key, |
| unnest($5) AS con_splits |
| ) AS con_splits |
| USING (grp_key) |
| ) s2 |
| USING (grp_key) |
| """ |
| train_sql = "SELECT grp_key, (result).* FROM (" + sql + ") sub" |
| train_sql = train_sql.format(aggregate=train_aggregate, |
| apply_func=train_apply_func, |
| # check_finished="AND " + finished + " = 0", |
| **locals()) |
| train_sql_plan = plpy.prepare(train_sql, |
| ['text[]', 'integer[]', bytea8arr, |
| 'text[]', bytea8arr]) |
| |
| unfinished_trees = [t for t in tree_states if t['finished'] == 0] |
| finished_trees = [t for t in tree_states if t['finished'] != 0] |
| |
    # Need to cast the output of plpy.execute to 'list' since newer versions
    # of Postgres return a PLyResult object instead of a list
| updated_unfinished = list(plpy.execute(train_sql_plan, [ |
| [t['grp_key'] for t in unfinished_trees], |
| [t['finished'] for t in unfinished_trees], |
| [t['tree_state'] for t in unfinished_trees], |
| bins['grp_key_con'], |
| bins['con']])) |
| |
| if max_n_surr > 0: |
| surr_apply_func = """ |
| {schema_madlib}._dt_surr_apply( |
| {tree_state}, |
| agg_result, |
| {con_splits}) |
| """.format(schema_madlib=schema_madlib, |
| tree_state=tree_state, |
| con_splits=con_splits) |
| |
| dup_count_expr = weights + "::integer" if subsample else '1' |
| surr_aggregate = """ |
| {schema_madlib}._compute_surr_stats( |
| {tree_state}, |
| {cat_features_str}, |
| {con_features_str}, |
| {cat_n_levels}, |
| {con_splits}, |
| {dup_count_expr}) |
| """.format(**locals()) |
| |
| # update the tree to compute surrogates for the layer above the just |
| # computed leaves |
| surr_sql = sql.format(aggregate=surr_aggregate, |
| apply_func=surr_apply_func, |
| check_finished="", |
| **locals()) |
| surr_sql_plan = plpy.prepare(surr_sql, |
| ['text[]', 'integer[]', bytea8arr, |
| 'text[]', bytea8arr]) |
| surr_trees = list(plpy.execute(surr_sql_plan, [ |
| [t['grp_key'] for t in updated_unfinished], |
| [t['finished'] for t in updated_unfinished], |
| [t['tree_state'] for t in updated_unfinished], |
| bins['grp_key_con'], |
| bins['con']])) |
| |
| surr_dict = dict() |
| for each_tree in surr_trees: |
| surr_dict[each_tree['grp_key']] = each_tree['result'] |
| for each_tree in updated_unfinished: |
| if each_tree['grp_key'] in surr_dict: |
| each_tree['tree_state'] = surr_dict[each_tree['grp_key']] |
| |
| # return all tree states (finished and unfinished) |
| updated_unfinished.extend(finished_trees) |
| |
| return updated_unfinished |
| # ------------------------------------------------------------ |
| |
| |
| def _create_grp_result_table( |
| schema_madlib, tree_states, bins, cat_features, |
| con_features, output_table_name, cat_features_info_table, |
| grouping_cols, |
| training_table_name, use_existing_tables=False, |
| running_cv=False, k=0): |
| """ Create the output table for grouping case. |
| """ |
| |
| grp_key = unique_string() |
| tree_state = unique_string() |
| tree_depth = unique_string() |
| cp_col = unique_string() |
| cat_levels_in_text = unique_string() |
| cat_n_levels = unique_string() |
| impurity_var_importance = unique_string() |
| cat_levels_val = cat_levels_in_text if cat_features else "NULL::TEXT[]" |
| cat_n_levels_val = cat_n_levels if cat_features else "NULL::INTEGER[]" |
| grouping_array_str = bins['grouping_array_str'] |
| |
| fold = ", " + str(k) + " as k" if running_cv else "" |
| if use_existing_tables: |
| # plpy.execute("truncate " + output_table_name) |
| header = "insert into " + output_table_name + " " |
| else: |
| header = "create table " + output_table_name + " as " |
| sql = header + """ |
| SELECT |
| {grouping_cols}, |
| {tree_state} AS tree, |
| {cat_levels_val} AS cat_levels_in_text, |
| {cat_n_levels_val} AS cat_n_levels, |
| string_to_array(trim(both '{{}}' FROM {impurity_var_importance}), |
| ',')::double precision[] AS impurity_var_importance, |
| {tree_depth} AS tree_depth, |
| {cp_col} AS pruning_cp |
| {fold} |
| FROM ( |
| SELECT |
| {grouping_cols}, |
| {grouping_array_str} AS {grp_key} |
| FROM {training_table_name} |
| group by {grouping_cols} |
| ) s1 |
| JOIN ( |
| SELECT |
| unnest($1) AS {grp_key}, |
| unnest($2) AS {tree_state}, |
| unnest($3) AS {tree_depth}, |
| unnest($4) AS {cp_col}, |
| unnest($5) AS {impurity_var_importance} |
| ) s2 |
| USING ({grp_key}) |
| """ |
| if cat_features: |
| sql += """ |
| JOIN ( |
| SELECT |
| grp_key as {grp_key}, |
| cat_n_levels as {cat_n_levels}, |
| cat_levels_in_text as {cat_levels_in_text} |
| FROM |
| {cat_features_info_table} |
| ) s3 |
| USING ({grp_key}) |
| """ |
| sql = sql.format(**locals()) |
| prepare_list = ['text[]', |
| '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib), |
| 'integer[]', 'double precision[]', 'text[]'] |
| execute_list = [ |
| [t['grp_key'] for t in tree_states], |
| [t['tree_state'] for t in tree_states], |
| [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth'] |
| for t in tree_states], |
| [t['cp'] for t in tree_states], |
| [_array_to_string(t['impurity_var_importance']) for t in tree_states]] |
| if cat_features: |
| prepare_list += ['text[]', 'integer[]', 'integer', 'text[]'] |
| execute_list += [ |
| bins['grp_key_cat'], |
| bins['cat_n'], |
| len(cat_features), |
| bins['cat_origin']] |
| sql_plan = plpy.prepare(sql, prepare_list) |
| plpy.execute(sql_plan, execute_list) |
| # ------------------------------------------------------------ |
| |
| |
| def _get_dep_type(data_table, dep): |
| """ |
| @brief Obtain the dependent_variable type |
| """ |
| table_schema_str, table_name = _get_table_schema_names(data_table) |
| dep_type = plpy.execute(""" |
| SELECT data_type from information_schema.columns |
| where |
| table_name = '{table_name}' and |
| table_schema in {table_schema} and |
| column_name = '{dep}' |
| """.format(table_name=table_name, |
| table_schema=table_schema_str, |
| dep=dep)) |
| if dep_type: |
| return dep_type[0]['data_type'] |
| else: |
| return get_expr_type(dep, data_table) |
| # ------------------------------------------------------------ |
| |
| |
| def _create_summary_table( |
| schema_madlib, split_criterion, |
| training_table_name, output_table_name, id_col_name, |
| cat_features, con_features, dependent_variable, list_of_features, |
| list_of_features_to_exclude, |
| num_failed_groups, is_classification, n_all_rows, n_rows, |
| dep_list, all_cols_types, cp, grouping_cols=None, n_groups=1, |
| use_existing_tables=False, n_folds=0, null_proxy=None): |
| |
| output_table_summary = add_postfix(output_table_name, "_summary") |
| |
| # dependent variables |
| dep_type = _get_dep_type(training_table_name, dependent_variable) |
| if dep_list: |
| if is_psql_boolean_type(dep_type): |
            # Special handling for booleans since Python booleans are
            # capitalized (i.e. 'False' instead of 'false')
            # Note: dep_list is sorted, hence 'false' will be first
| dep_list_str = "'false, true'" |
| else: |
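            # Dollar-quoting avoids having to escape any single quotes that
            # may appear in the dependent variable levels.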
| dep_list_str = '$__dep_list__${0}$__dep_list__$'.format( |
| ','.join(map(str, dep_list))) |
| else: |
| dep_list_str = "NULL" |
| indep_type = ', '.join(all_cols_types[c] for c in cat_features + con_features) |
| independent_varnames = ','.join(cat_features + con_features) |
| cat_features_str = ','.join(cat_features) |
| con_features_str = ','.join(con_features) |
| if grouping_cols: |
| grouping_cols_str = "'{0}'".format(grouping_cols) |
| else: |
| grouping_cols_str = "NULL" |
| n_rows_skipped = n_all_rows - n_rows |
| |
| if isinstance(cp, Iterable): |
| cp_str = py_list_to_sql_string(cp, 'double precision') |
| else: |
| cp_str = str(cp) + "::double precision" |
| |
| if use_existing_tables: |
| plpy.execute("TRUNCATE " + output_table_summary) |
| header = "INSERT INTO {0} ".format(output_table_summary) |
| else: |
| header = "CREATE TABLE {0} AS ".format(output_table_summary) |
    null_proxy_str = "NULL" if null_proxy is None else "'{0}'".format(null_proxy)
| sql = header + """ |
| SELECT |
| 'tree_train'::text AS method, |
| '{is_classification}'::boolean AS is_classification, |
| '{training_table_name}'::text AS source_table, |
| '{output_table_name}'::text AS model_table, |
| '{id_col_name}'::text AS id_col_name, |
| '{list_of_features}'::text AS list_of_features, |
| '{list_of_features_to_exclude}'::text AS list_of_features_to_exclude, |
| '{dependent_variable}'::text AS dependent_varname, |
| '{independent_varnames}'::text AS independent_varnames, |
| '{cat_features_str}'::text AS cat_features, |
| '{con_features_str}'::text AS con_features, |
| {grouping_cols_str}::text AS grouping_cols, |
| {n_groups}::integer AS num_all_groups, |
| {num_failed_groups}::integer AS num_failed_groups, |
| {n_rows}::integer AS total_rows_processed, |
| {n_rows_skipped}::integer AS total_rows_skipped, |
| {dep_list_str}::text AS dependent_var_levels, |
| '{dep_type}'::text AS dependent_var_type, |
| '{indep_type}'::text AS independent_var_types, |
| {cp_str} AS input_cp, |
| {n_folds}::integer AS n_folds, |
| {null_proxy_str}::text AS null_proxy |
| """.format(**locals()) |
| plpy.execute(sql) |
| # ------------------------------------------------------------ |
| |
| |
| def _get_filter_str(dependent_variable, grouping_cols): |
| """ Return a 'WHERE' clause string that filters out all rows that contain a |
| NULL. |
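
    Example (hypothetical column names):
        _get_filter_str('response', 'g1,g2') returns
        '(g1) is not NULL and (g2) is not NULL and (response) is not NULL'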
| """ |
| if grouping_cols: |
| group_filter = ' and '.join('({0}) is not NULL'.format(g.strip()) |
| for g in grouping_cols.split(',')) |
| else: |
| group_filter = None |
| dep_filter = '(' + dependent_variable + ") is not NULL" |
| return ' and '.join(filter(None, [group_filter, dep_filter])) |
| # ------------------------------------------------------------------------- |
| |
| |
| def _validate_predict(schema_madlib, model, source, output, use_existing_tables): |
| # validations for inputs |
| _assert(source and source.strip().lower() not in ('null', ''), |
| "Decision tree error: Invalid data table name: {0}".format(source)) |
| _assert(table_exists(source), |
| "Decision tree error: Data table ({0}) does not exist".format(source)) |
| _assert(not table_is_empty(source), |
| "Decision tree error: Data table ({0}) is empty".format(source)) |
| _assert(model and |
| model.strip().lower() not in ('null', ''), |
| "Decision tree error: Invalid model table name: {0}".format(model)) |
| _assert(table_exists(model), |
| "Decision tree error: Model table ({0}) does not exist".format(model)) |
| _assert(not table_is_empty(model), |
| "Decision tree error: Model table ({0}) is empty".format(model)) |
| model_summary = add_postfix(model, "_summary") |
| _assert(table_exists(model_summary), |
| "Decision tree error: Model summary table ({0}) does not exist".format(model_summary)) |
| _assert(not table_is_empty(model_summary), |
| "Decision tree error: Model summary table ({0}) is empty".format(model_summary)) |
| _assert(output and |
| output.strip().lower() not in ('null', ''), |
| "Decision tree error: Invalid output table name: {0}".format(output)) |
| if not use_existing_tables: |
| _assert(not table_exists(output), |
| "Decision tree error: Output table ({0}) already exists".format(output)) |
| _assert( |
| columns_exist_in_table( |
| model, |
| ["tree", "cat_levels_in_text", "cat_n_levels"], |
| schema_madlib), |
| "Decision tree error: Invalid model table ({0})".format(model)) |
| _assert( |
| columns_exist_in_table( |
| model_summary, |
| ["grouping_cols", "id_col_name", "dependent_varname", |
| "cat_features", "con_features", "is_classification"], |
| schema_madlib), |
| "Decision tree error: Invalid model summary table ({0})".format(model_summary)) |
| # ------------------------------------------------------------------------- |
| |
| |
| def tree_predict(schema_madlib, model, source, output, pred_type='response', |
| use_existing_tables=False, k=0, **kwargs): |
| """ |
| Args: |
| @param schema_madlib: str, Name of MADlib schema |
| @param model: str, Name of table containing the tree model |
| @param source: str, Name of table containing prediction data |
| @param output: str, Name of table to output the results |
| @param pred_type: str, The type of output required: |
| 'response' gives the actual response values, |
| 'prob' gives the probability of the classes in a |
| classification tree. |
                    For a regression tree, only pred_type='response' is defined.
| Returns: |
| None |
| |
| Side effect: |
| Creates an output table containing the prediction for given source table |
| |
| Throws: |
| None |
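
    Example (hypothetical table names, invoked via the SQL interface):
        SELECT madlib.tree_predict('tree_model', 'new_data_table',
                                   'output_table', 'response');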
| """ |
| _validate_predict(schema_madlib, model, source, output, use_existing_tables) |
| model_summary = add_postfix(model, "_summary") |
| |
| # obtain the cat_features and con_features from model summary table |
| summary_elements = plpy.execute("SELECT * FROM {0}".format(model_summary))[0] |
| |
| list_of_features = split_quoted_delimited_str(summary_elements["independent_varnames"]) |
| cat_features = split_quoted_delimited_str(summary_elements["cat_features"]) |
| con_features = split_quoted_delimited_str(summary_elements["con_features"]) |
| _assert(is_var_valid(source, ','.join(list_of_features)), |
| "Decision tree error: Missing columns in predict data table ({0}) " |
| "that were used during training".format(source)) |
| id_col_name = summary_elements["id_col_name"] |
| dep_varname = summary_elements["dependent_varname"] |
| dep_levels = split_quoted_delimited_str(summary_elements["dependent_var_levels"]) |
| dep_type = summary_elements['dependent_var_type'] |
| is_classification = summary_elements["is_classification"] |
| # optional variables, default value is None |
| grouping_cols_str = summary_elements.get("grouping_cols") |
| null_proxy = summary_elements.get('null_proxy') |
| |
| # find which columns are of type boolean |
| boolean_cats = set([key for key, value in get_cols_and_types(source) |
| if value == 'boolean']) |
| |
| cat_features_str, con_features_str = get_feature_str( |
| schema_madlib, source, cat_features, con_features, |
| "m.cat_levels_in_text", "m.cat_n_levels", null_proxy) |
| |
| if use_existing_tables and table_exists(output): |
| plpy.execute("TRUNCATE " + output) |
| header = "INSERT INTO " + output |
| use_fold = 'WHERE k = ' + str(k) |
| else: |
| header = "CREATE TABLE " + output + " AS " |
| use_fold = '' |
| |
| if not grouping_cols_str: |
| using_str = "" |
| join_str = "," |
| else: |
| using_str = "USING ( " + grouping_cols_str + ")" |
| join_str = "LEFT OUTER JOIN" |
| |
| pred_name = ('"prob_{0}"' if pred_type == "prob" else |
| '"estimated_{0}"').format(dep_varname.replace('"', '').strip()) |
| |
| if not is_classification: |
| sql = header + """ |
| SELECT {id_col_name}, |
| {schema_madlib}._predict_dt_response( |
| tree, |
| {cat_features_str}::INTEGER[], |
| {con_features_str}::DOUBLE PRECISION[]) as {pred_name} |
| FROM {source} as s |
| {join_str} |
| {model} as m |
| {using_str} |
| {use_fold} |
| """ |
| else: |
| if is_psql_boolean_type(dep_type): |
            # Some platforms don't provide a text-to-boolean cast, so we
            # match the string values manually.
| dep_cast_str = ("(case {pred_name} when 'true' then true " |
| " when 'false' then false " |
| "end)::BOOLEAN as {pred_name}") |
| else: |
| dep_cast_str = "{pred_name}::{dep_type}" |
| dep_levels_array_str = py_list_to_sql_string(map(quote_literal, dep_levels), |
| 'TEXT', |
| long_format=True) |
| if pred_type == "response": |
| sql = header + """ |
| SELECT |
| {id_col_name}, |
| %s |
| FROM ( |
| SELECT |
| {id_col_name}, |
| -- _predict_dt_response returns 0-based indexing. |
| -- Hence the "+ 1" (DB by default uses 1-based indexing) |
| (%s)[{schema_madlib}._predict_dt_response ( |
| tree, |
| {cat_features_str}::INTEGER[], |
| {con_features_str}::DOUBLE PRECISION[]) + 1]::TEXT |
| as {pred_name} |
| FROM {source} as s {join_str} {model} as m {using_str} |
| {use_fold} |
| ) q |
| """ % (dep_cast_str, dep_levels_array_str) |
| else: |
| temp_col = unique_string() |
| score_format = ', \n'.join([ |
| '{0}[{1}] as "estimated_prob_{2}"'.format(temp_col, i, c.strip(' "')) |
| for i, c in enumerate(dep_levels, start=1)]) |
| sql = header + """ |
| SELECT |
| {id_col_name}, |
| %s |
| FROM ( |
| SELECT {id_col_name}, |
| {schema_madlib}._predict_dt_prob(tree, |
| {cat_features_str}::INTEGER[], |
| {con_features_str}::DOUBLE PRECISION[]) |
| AS {temp_col} |
| FROM {source} as s {join_str} {model} as m {using_str} |
| {use_fold} |
| ) q |
| """ % (score_format) |
| sql = sql.format(**locals()) |
| with MinWarning('warning'): |
| with OptimizerControl(False): |
| # we disable optimizer (ORCA) for platforms that use it |
| # since ORCA doesn't provide an easy way to disable hashagg |
| with HashaggControl(False): |
                # we disable hashagg since a large number of groups could
                # result in excessive memory usage.
| plpy.execute(sql) |
| return None |
| # ------------------------------------------------------------------------- |
| |
| |
| def _get_display_header(table_name, dep_levels, is_regression, dot_format=True): |
| if dot_format: |
| tree_type = ("Classification", "Regression")[is_regression] |
| return ("digraph {0} {{".format('"' + tree_type + ' tree for ' + |
| str(table_name) + '"')) |
| else: |
| return_str = """------------------------------------- |
            - Each node is represented by its 'id' inside ().
            - Each internal node has the split condition at the end, while each
                leaf node has a * at the end.
| - For each internal node (i), its child nodes are indented by 1 level |
| with ids (2i+1) for True node and (2i+2) for False node. |
| """ |
| if is_regression: |
| return_str += ("- Number of rows and average response value inside []. " |
| "For a leaf node, this is the prediction.\n") |
| else: |
| return_str += """- Number of (weighted) rows for each response variable inside [].' |
| The response label order is given as {0}. |
| For each leaf, the prediction is given after the '-->' |
| """.format(str(dep_levels)) |
| return_str += "\n-------------------------------------" |
| return return_str |
| # ------------------------------------------------------------------------------ |
| |
| |
| def tree_display(schema_madlib, model_table, dot_format=True, verbose=False, |
| disp_surr=False, **kwargs): |
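    """ Display the trained decision tree(s) from a model table, either in
    DOT graph format or as indented plain text (one tree per group when
    grouping columns were used). Surrogate splits can be displayed only in
    the text format.
    """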
| |
| if dot_format: |
| disp_surr = False # surrogates cannot be displayed in dot format |
| bytea8 = schema_madlib + '.bytea8' |
| summary_table = add_postfix(model_table, "_summary") |
| summary = plpy.execute("SELECT * FROM {summary_table}". |
| format(summary_table=summary_table))[0] |
| dep_levels = summary["dependent_var_levels"] |
| dep_levels = [''] if not dep_levels else split_quoted_delimited_str(dep_levels) |
| table_name = summary["source_table"] |
| is_regression = not summary["is_classification"] |
| cat_features_str = split_quoted_delimited_str(summary['cat_features']) |
| con_features_str = split_quoted_delimited_str(summary['con_features']) |
| |
| grouping_cols = summary["grouping_cols"] |
| grouping_cols = '' if grouping_cols is None else grouping_cols + ',' |
| grouped_trees = plpy.execute("SELECT {grouping_cols} " |
| "tree, cat_levels_in_text, cat_n_levels " |
| "FROM {model_table}". |
| format(model_table=model_table, |
| grouping_cols=grouping_cols)) |
| |
| return_str_list = [] |
| if not disp_surr: |
| return_str_list.append(_get_display_header(table_name, dep_levels, |
| is_regression, dot_format)) |
| else: |
| return_str_list.append(""" |
| ------------------------------------- |
| Surrogates for internal nodes |
| ------------------------------------- |
| """) |
| |
| with MinWarning('warning'): |
| for index, dtree in enumerate(grouped_trees): |
| grouping_keys = [k + ' = ' + str(dtree[k]) for k in dtree.keys() |
| if k not in ('tree', 'cat_features', 'con_features', |
| 'cat_n_levels', 'cat_levels_in_text')] |
| if grouping_keys: |
| group_name = "(" + ','.join(grouping_keys) + ")" |
| else: |
| group_name = '' |
| tree = dtree['tree'] |
| if dtree['cat_levels_in_text']: |
| cat_levels_in_text = dtree['cat_levels_in_text'] |
| cat_n_levels = dtree['cat_n_levels'] |
| else: |
| cat_levels_in_text = [] |
| cat_n_levels = [] |
| |
| if disp_surr: |
| if group_name: |
| return_str_list.append("--- Surrogates for tree {0} ---".format(group_name)) |
| sql = """SELECT {0}._display_decision_tree_surrogate( |
| $1, $2, $3, $4, $5) as display_tree |
| """.format(schema_madlib) |
| # execute sql to get display string |
| sql_plan = plpy.prepare(sql, [bytea8, |
| 'text[]', 'text[]', 'text[]', |
| 'int[]']) |
| tree_display = plpy.execute( |
| sql_plan, [tree, cat_features_str, con_features_str, |
| cat_levels_in_text, cat_n_levels])[0] |
| else: |
| if dot_format: |
| return_str_list.append('\t subgraph "cluster{0}"{{'.format(index)) |
| return_str_list.append('\t label="{0}"'.format(group_name.replace('"', '\\"'))) |
| sql = """ |
| SELECT {0}._display_decision_tree( |
| $1, $2, $3, $4, $5, $6, '{1}', {2} |
| ) as display_tree |
| """.format(schema_madlib, "g" + str(index) + "_", verbose) |
| else: |
| if group_name: |
| return_str_list.append("--- Tree for {0} ---".format(group_name)) |
| sql = """ |
| SELECT {0}._display_text_decision_tree( |
| $1, $2, $3, $4, $5, $6) as display_tree |
| """.format(schema_madlib) |
| sql_plan = plpy.prepare(sql, [bytea8, |
| 'text[]', 'text[]', 'text[]', |
| 'int[]', 'text[]']) |
| tree_display = plpy.execute( |
| sql_plan, [tree, cat_features_str, con_features_str, |
| cat_levels_in_text, cat_n_levels, dep_levels])[0] |
| |
| return_str_list.append(tree_display["display_tree"]) |
| if dot_format: |
| return_str_list.append("\t } //--- end of subgraph------------") |
| else: |
| return_str_list.append("-------------------------------------") |
| if dot_format: |
| return_str_list.append("} //---end of digraph--------- ") |
| return ("\n".join(return_str_list)) |
| # ------------------------------------------------------------------------- |
| |
| |
| def _prune_and_cplist(schema_madlib, tree, cp, compute_cp_list=False): |
| """ Prune tree with given cost-complexity parameters |
| and return a list of cp values at which tree can be pruned |
| |
| Args: |
| @param schema_madlib: str, MADlib schema name |
| @param tree: Tree data to prune |
| @param cp: float, cost-complexity parameter, all splits that have a |
| complexity lower than 'cp' will be pruned |
| @param compute_cp_list: bool, optionally return a list of cp values that |
| are the various complexity boundaries for the splits |
| in the tree. This list can be used by cross-validation |
| to explore different cp values. |
| |
| Returns: |
| Dictionary containing following keys: |
| tree_state: pruned tree state |
| pruned_depth: depth of tree after pruning |
| cp_list: list of cp values at which tree can be pruned |
| (returned only if compute_cp_list=True) |
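
    Example (a sketch; assumes 'tree' is a dict with a 'tree_state' key):
        pruned = _prune_and_cplist(schema_madlib, tree, 0.01,
                                   compute_cp_list=True)
        # pruned['tree_state'], pruned['pruned_depth'], pruned['cp_list']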
| """ |
| sql = """ |
| SELECT (pruned_tree).* |
| FROM ( |
| SELECT {madlib}._prune_and_cplist( |
| $1, |
| ({cp})::double precision, |
| ({compute_cp_list})::boolean |
| ) as pruned_tree |
| ) q |
| """.format(madlib=schema_madlib, cp=cp, |
| compute_cp_list=bool(compute_cp_list)) |
| |
| sql_plan = plpy.prepare(sql, [schema_madlib + '.bytea8']) |
| pruned_tree = plpy.execute(sql_plan, [tree['tree_state']])[0] |
| return pruned_tree |
| # ------------------------------------------------------------------------- |
| |
| |
| def _compute_var_importance(schema_madlib, tree, |
| n_cat_features, n_con_features): |
| """ Compute variable importance for categorical and continuous features |
| |
| Args: |
| @param schema_madlib: str, MADlib schema name |
| @param tree: dict. tree['tree_state'] is the trained tree (in byte form) |
| @param n_cat_features: int, Number of categorical features |
| @param n_con_features: int, Number of continuous features |
| |
| Returns: |
| Dictionary containing following keys: |
| impurity_var_importance: Array of importance values |
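
    Example (a sketch, mirroring the call made in _xvalidate below):
        importance = _compute_var_importance(schema_madlib, tree,
                                             len(cat_features),
                                             len(con_features))
        tree.update(**importance)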
| """ |
| var_imp_sql = """ |
| SELECT {schema_madlib}._compute_var_importance( |
| $1, -- trained decision tree |
| {n_cat_features}, |
| {n_con_features}) AS impurity_var_importance |
| """.format(**locals()) |
| var_imp_plan = plpy.prepare(var_imp_sql, [schema_madlib + '.bytea8']) |
| return plpy.execute(var_imp_plan, [tree['tree_state']])[0] |
| # ------------------------------------------------------------------------------ |
| |
| |
| def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_name, |
| id_col_name, dependent_variable, |
| list_of_features, list_of_features_to_exclude, |
| cat_features, ordered_cat_features, boolean_cats, con_features, |
| split_criterion, grouping_cols, weights, max_depth, |
| min_split, min_bucket, n_bins, is_classification, |
| dep_is_bool, dep_n_levels, n_folds, n_rows, |
| max_n_surr, null_proxy, msg_level='warning', **kwargs): |
| """ |
| Run cross validation for decision tree over multiple cp values |
| |
| Returns: |
| cp value corresponding to lowest cross validation error |
| |
| Side effect: |
| Creates a table containing cross-validation error for each cp value |
| """ |
| # 1) create group_to_param_list_table for CV |
| group_to_param_list_table = unique_string() |
| param_list_name = unique_string() |
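    # This table maps each group to the array of candidate cp values to
    # explore; without grouping it holds a single row with just the array.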
| if not grouping_cols: |
| plan = plpy.prepare(""" |
| CREATE TEMP TABLE {group_to_param_list_table} AS |
| SELECT $1 AS {param_list_name} |
| """.format(**locals()), ["float8[]"]) |
| plpy.execute(plan, [tree_states[0]['cp_list']]) |
| plpy.notice("Running cross validation for cp values: {0}".format( |
| str(tree_states[0]['cp_list']))) |
| else: |
| grp_list = [] |
| cp_list = [] |
| plpy.notice("Running cross validation for cp values:") |
| for tree in tree_states: |
| grp_list.extend([tree['grp_key']]*len(tree['cp_list'])) |
| cp_list.extend(tree['cp_list']) |
| plpy.notice("{0} -> {1}".format(tree['grp_key'], str(tree['cp_list']))) |
| |
| grouping_array_str = get_grouping_array_str(training_table_name, |
| grouping_cols) |
| grp_key_str = unique_string() |
| plan = plpy.prepare(""" |
| CREATE TEMP TABLE {group_to_param_list_table} AS |
| SELECT |
| {grouping_cols}, |
| {param_list_name} |
| FROM |
| ( |
| SELECT {grp_key_str}, |
| array_agg(cp_value) as {param_list_name} |
| FROM |
| ( |
| SELECT unnest($1) as {grp_key_str}, |
| unnest($2) as cp_value |
| ) unnested_grp_cp |
| GROUP BY {grp_key_str} |
| ) grp_key_to_cp_list |
| JOIN |
| ( |
| SELECT |
| {grouping_cols}, |
| {grouping_array_str} as {grp_key_str} |
| FROM {training_table_name} |
| ) grp_key_to_grouping_cols |
| USING ({grp_key_str}) |
| """.format(**locals()), ["text[]", "float8[]"]) |
| plpy.execute(plan, [grp_list, cp_list]) |
| |
| # 2) call CV function to actually cross-validate _build_tree |
    # expects output table model_cv(<grouping_cols>, cp, avg, stddev)
| model_cv = output_table_name + "_cv" |
| metric_function = "_tree_misclassified" if is_classification else "_tree_rmse" |
| pred_name = '"estimated_{0}"'.format(dependent_variable.strip(' "')) |
| grouping_str = 'NULL' if not grouping_cols else '"' + grouping_cols + '"' |
| |
| all_features = [cat_features, ordered_cat_features, boolean_cats, con_features] |
| |
    # _get_xvalidate_params builds the parameters used in the DT train,
    # predict, and error-metric functions. Single quotes are added to these
    # parameters (except for the feature arrays) since we run
    # cross_validation_grouping_w_params with add_param_quotes=False.
    # This special handling ensures the feature arrays are treated as
    # arrays instead of strings.
| xvalidate_params = _get_xvalidate_params(**locals()) |
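    # xvalidate_params unpacks as: [0] train params, [1] their types,
    # [2] predict params, [3] their types, [4] metric params, [5] their types.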
| cross_validation_grouping_w_params( |
| schema_madlib, |
| schema_madlib + '.__build_tree', |
| xvalidate_params[0], |
| xvalidate_params[1], |
| schema_madlib + '.__tree_predict', |
| xvalidate_params[2], |
| xvalidate_params[3], |
| schema_madlib + "." + metric_function, |
| xvalidate_params[4], |
| xvalidate_params[5], |
| group_to_param_list_table, param_list_name, grouping_cols, |
| training_table_name, id_col_name, False, |
| model_cv, 'cp', None, n_folds, |
| add_param_quotes=False) |
| |
| # 3) find the best cp for each group from table {model_cv} |
| if not grouping_cols: |
| grouping_array_str_comma = "''::TEXT AS grp_key," |
| group_by_str = '' |
| qualified_group_by_str = '' |
| group_select_str = '' |
| else: |
| grouping_array_qualified_str = get_grouping_array_str(training_table_name, |
| grouping_cols, |
| qualifier="min_cv_error") |
| grouping_array_str_comma = grouping_array_qualified_str + " AS grp_key," |
| grouping_cols_list = grouping_cols.split(",") |
| qualified_group_by_str = "GROUP BY " + ','.join("min_cv_error." + i for i in grouping_cols_list) |
| group_by_str = "GROUP BY " + grouping_cols |
| group_select_str = grouping_cols + "," |
| |
| plpy.notice(str(list(plpy.execute("SELECT * FROM {0}".format(model_cv))))) |
| validation_result_query = """ |
| SELECT |
| {grouping_array_str_comma} |
| max(cp) AS cp |
| FROM |
| {model_cv} |
| NATURAL JOIN |
| ( SELECT |
| {group_select_str} |
| min(cv_error_avg) AS cv_error_avg |
| FROM {model_cv} |
| {group_by_str} |
| ) min_cv_error |
| {qualified_group_by_str} |
| """.format(**locals()) |
| plpy.notice("validation_result_query:") |
| plpy.notice(validation_result_query) |
| validation_result = plpy.execute(validation_result_query) |
| plpy.notice("finished validation_result_query, validation_result = " + str(list(validation_result))) |
| |
| grp_key_to_best_cp = dict((row['grp_key'], row['cp']) for row in validation_result) |
| |
| # 4) update tree_states to have the best cp cross-validated |
| for tree in tree_states: |
| best_cp = grp_key_to_best_cp[tree['grp_key']] |
| if best_cp > tree['cp']: |
| tree['cp'] = best_cp |
        # Prune each tree with the cross-validated best cp to obtain the
        # optimal pruned tree. This time we don't need the cp_list.
| pruned_tree = _prune_and_cplist(schema_madlib, |
| tree, |
| tree['cp'], |
| compute_cp_list=False) |
| tree['tree_state'] = pruned_tree['tree_state'] |
| if 'pruned_depth' in pruned_tree: |
| tree['pruned_depth'] = pruned_tree['pruned_depth'] |
| elif 'tree_depth' in pruned_tree: |
| tree['pruned_depth'] = pruned_tree['tree_depth'] |
| else: |
| tree['pruned_depth'] = 0 |
| importance_vectors = _compute_var_importance( |
| schema_madlib, tree, |
| len(cat_features), len(con_features)) |
| tree.update(**importance_vectors) |
| |
| plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals())) |
| # ------------------------------------------------------------ |
| |
| |
| def _get_xvalidate_params(**kwargs): |
| """ Build train, predict, and metric parameters for cross_validation |
| |
| Args: |
| @param all_features |
| |
| Returns: |
| |
| """ |
| def _list_to_string_to_array(array_input): |
| """ Return a string that can interpreted by postgresql as text[] containing |
| the names in array_input |
| |
| Example: |
| Input: ['"Cont_features"[1]', '"Cont_features"[2]'] |
| Output: string_to_array('"Cont_features"[1]~^~"Cont_features"[2]'::text, '~^~'); |
| |
            When this output is executed by PostgreSQL it creates a text array:
| madlib=# select string_to_array('"Cont_features"[1]~^~"Cont_features"[2]'::text, '~^~')::VARCHAR[] as t; |
| t |
| ------------------------------------------------- |
| {"\"Cont_features\"[1]","\"Cont_features\"[2]"} |
| (1 row) |
| """ |
| if not array_input: |
| return "'{}'" |
| return "string_to_array('{0}', '~^~')".format('~^~'.join(array_input)) |
| |
| all_feature_str = [_list_to_string_to_array(i) for i in kwargs['all_features']] |
| |
| def _add_quote(s): |
| if s is None: |
| return None |
| s = str(s) |
| return "NULL" if s.upper() == 'NULL' else "'{0}'".format(s) |
| |
| quoted_args = {} |
| for k, v in kwargs.items(): |
| quoted_args[k] = _add_quote(v) |
| |
| modeling_params = [quoted_args['is_classification'], |
| quoted_args['split_criterion'], |
| "%data%", |
| "%model%", |
| quoted_args['id_col_name'], |
| quoted_args['dependent_variable'], |
| quoted_args['dep_is_bool'], |
| quoted_args['list_of_features'], |
| all_feature_str[0], |
| all_feature_str[1], |
| all_feature_str[2], |
| all_feature_str[3], |
| quoted_args['grouping_str'], |
| quoted_args['weights'], |
| quoted_args['max_depth'], |
| quoted_args['min_split'], |
| quoted_args['min_bucket'], |
| quoted_args['n_bins'], |
| "%explore%", |
| quoted_args['max_n_surr'], |
| quoted_args['msg_level'], |
| quoted_args['null_proxy'] |
| ] |
| modeling_param_types = (["BOOLEAN"] + ["TEXT"] * 5 + ["BOOLEAN"] + ["TEXT"] + |
| ["VARCHAR[]"] * 4 + ["TEXT"] * 2 + ["INTEGER"] * 4 + |
| ["TEXT", "SMALLINT", "TEXT", "TEXT"]) |
| predict_params = ["%model%", "%data%", "%prediction%", "'response'", "True"] |
| predict_param_types = (["VARCHAR"] * 4 + ["BOOLEAN"]) |
| metric_params = ["%data%", |
| quoted_args['dependent_variable'], |
| "%prediction%", |
| quoted_args['pred_name'], |
| quoted_args['id_col_name'], |
| quoted_args['grouping_cols'], |
| "%error%", |
| "True"] |
| metric_param_types = ["VARCHAR", "VARCHAR", "VARCHAR", "VARCHAR", |
| "VARCHAR", "TEXT", "VARCHAR", "BOOLEAN"] |
| return [modeling_params, modeling_param_types, |
| predict_params, predict_param_types, |
| metric_params, metric_param_types] |
| # ---------------------------------------------------------------------- |
| |
| |
| def _tree_train_using_bins( |
| schema_madlib, bins, training_table_name, |
| cat_features, con_features, boolean_cats, n_bins, weights, |
| dep_var_str, min_split, min_bucket, max_depth, filter_dep, |
| dep_n_levels, is_classification, split_criterion, |
| subsample=False, n_random_features=1, max_n_surr=0, null_proxy=None, |
| **kwargs): |
| """Trains a tree without grouping columns""" |
| # Iterations for training the tree |
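    # The tree is grown one level per call to _one_step; training stops when
    # the aggregate reports the tree as finished (or terminated).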
| tree_state = plpy.execute( |
| """ |
| SELECT {schema_madlib}._initialize_decision_tree( |
| {is_regression_tree}, |
| '{split_criterion}'::text, |
| {dep_n_levels}::smallint, |
| {max_n_surr}::smallint |
| ) AS tree_state, |
| FALSE as finished |
| """.format(schema_madlib=schema_madlib, |
| is_regression_tree=(not is_classification), |
| split_criterion=split_criterion, |
| dep_n_levels=dep_n_levels, |
| max_n_surr=max_n_surr))[0] |
| plpy.notice("Starting tree building") |
| tree_depth = -1 |
| while tree_state['finished'] == 0: |
| # finished: 0 = running, 1 = finished training, 2 = terminated prematurely |
| tree_depth += 1 |
| tree_state = _one_step( |
| schema_madlib, training_table_name, |
| cat_features, con_features, boolean_cats, bins, |
| n_bins, tree_state, weights, dep_var_str, |
| min_split, min_bucket, max_depth, filter_dep, |
| dep_n_levels, subsample, n_random_features, max_n_surr, null_proxy) |
| plpy.notice("Completed training of level {0}".format(tree_depth)) |
| |
| return tree_state |
| # ------------------------------------------------------------------------------ |
| |
| |
| def _tree_train_grps_using_bins( |
| schema_madlib, bins, training_table_name, |
| cat_features, con_features, boolean_cats, n_bins, weights, |
| grouping_cols, grouping_array_str, |
| dep_var_str, min_split, min_bucket, max_depth, filter_dep, |
| dep_n_levels, is_classification, split_criterion, |
| cat_features_info_table, |
| subsample=False, n_random_features=1, tree_terminated=None, |
| max_n_surr=0, null_proxy=None, |
| **kwargs): |
| |
| """Trains a tree with grouping columns included """ |
| # Iterations for training the tree |
| initialized_tree_state = plpy.execute( |
| """ |
| SELECT |
| {schema_madlib}._initialize_decision_tree( |
| {is_regression_tree}, |
| '{split_criterion}'::text, |
| {dep_n_levels}::smallint, |
| {max_n_surr}::smallint) AS tree_state, |
| 0 AS finished |
| """.format(schema_madlib=schema_madlib, |
| is_regression_tree=not is_classification, |
| split_criterion=split_criterion, |
| dep_n_levels=dep_n_levels, |
| max_n_surr=max_n_surr))[0] |
    tree_states = [dict(initialized_tree_state, grp_key=key)
                   for key in bins['grp_key_con']]
| |
| # The following is only used in random forest: |
| # If a group already has a terminated tree (not finished |
| # properly), then we do not continue to compute more trees |
| # for that specific group of data. |
| if tree_terminated is not None: |
| for item in tree_states: |
| if item['grp_key'] in tree_terminated and \ |
| tree_terminated[item['grp_key']] == 2: # terminated |
| item['finished'] = 2 # won't continue in _one_step_for_grps |
| |
| plpy.notice("Started tree building for all groups") |
| level = 0 |
| while not all(t['finished'] for t in tree_states): |
| tree_states = _one_step_for_grps( |
| schema_madlib, training_table_name, cat_features, |
| con_features, boolean_cats, bins, n_bins, |
| tree_states, weights, grouping_cols, |
| grouping_array_str, dep_var_str, min_split, min_bucket, |
| max_depth, filter_dep, dep_n_levels, subsample, |
| n_random_features, cat_features_info_table, |
| max_n_surr, null_proxy) |
| level += 1 |
| plpy.notice("Finished training for level " + str(level)) |
| |
| return tree_states |
| # ------------------------------------------------------------ |
| |
| |
| def _tree_error(schema_madlib, source_table, dependent_varname, |
| prediction_table, pred_dep_name, id_col_name, grouping_cols, |
| output_table, is_classification, |
| use_existing_tables=False, k=0, **kwargs): |
| with MinWarning("warning"): |
| if use_existing_tables and table_exists(output_table): |
| # plpy.execute("truncate " + output_table) |
| header = "INSERT INTO " + output_table + " " |
| else: |
| header = "CREATE TABLE " + output_table + " AS " |
| if is_classification: |
| error_func = """ |
| 1.0 * sum(CASE WHEN ({prediction_table}.{pred_dep_name} = |
| {source_table}.{dependent_varname}) |
| THEN 0 |
| ELSE 1 |
| END) / count(*) |
| """.format(**locals()) |
| else: |
| error_func = """ |
| sqrt(avg(({prediction_table}.{pred_dep_name} - |
| {source_table}.{dependent_varname})^2 |
| ) |
| ) |
| """.format(**locals()) |
| grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols |
| grouping_col_str = '' if not grouping_cols else grouping_cols + ',' |
| |
| sql = header + """ |
| SELECT |
| {grouping_col_str} |
| {error_func} as cv_error, |
| {k} as k |
| FROM {prediction_table}, {source_table} |
| WHERE {prediction_table}.{id_col_name} = {source_table}.{id_col_name} |
| {grouping_str} |
| """.format(**locals()) |
| plpy.execute(sql) |
| # ------------------------------------------------------------ |
| |
| def tree_train_help_message(schema_madlib, message, **kwargs): |
| """ Help message for Decision Tree |
| """ |
| if not message: |
| help_string = """ |
| ------------------------------------------------------------ |
| SUMMARY |
| ------------------------------------------------------------ |
| Functionality: Decision Tree |
| |
| Decision trees use a tree-based predictive model to |
| predict the value of a target variable based on several input variables. |
| |
| For more details on the function usage: |
| SELECT {schema_madlib}.tree_train('usage'); |
| """ |
| elif message.lower().strip() in ['usage', 'help', '?']: |
| help_string = """ |
| ------------------------------------------------------------ |
| USAGE |
| ------------------------------------------------------------ |
| SELECT {schema_madlib}.tree_train( |
| 'training_table', -- Data table name |
| 'output_table', -- Table name to store the tree model |
| 'id_col_name', -- Row ID, used in tree_predict |
| 'dependent_variable', -- The column to fit |
| 'list_of_features', -- Comma separated column names to be |
| used as the predictors, can be '*' |
| to include all columns except the |
| dependent_variable |
| 'features_to_exclude', -- Comma separated column names to be |
| excluded if list_of_features is '*' |
| 'split_criterion', -- How to split a node, options are |
| 'gini', 'misclassification' and |
| 'entropy' for classification, and |
| 'mse' for regression. |
| 'grouping_cols', -- Comma separated column names used to |
| group the data. A decision tree model |
| will be created for each group. Default |
| is NULL |
| 'weights', -- A Column name containing weights for |
| each observation. Default is NULL |
| max_depth, -- Maximum depth of any node, default is 7 |
| min_split, -- Minimum number of observations that must |
| exist in a node for a split to be |
                                   attempted, default is 20
| min_bucket, -- Minimum number of observations in any |
| terminal node, default is min_split/3 |
| n_bins, -- Number of bins to find possible node |
| split threshold values for continuous |
| variables, default is 20 (Must be greater than 1) |
| pruning_params, -- A comma-separated text containing |
| key=value pairs of parameters for pruning. |
| Parameters accepted: |
| 'cp' - complexity parameter with default=0.01, |
| 'n_folds' - number of cross-validation folds |
| with default value of 0 (= no cross-validation) |
| null_handling_params, -- A comma-separated text containing |
| key=value pairs of parameters for handling NULL values. |
| Parameters accepted: |
| 'max_surrogates' - Maximum number of surrogates to |
| compute for each split |
| 'null_as_category' - Boolean to indicate if |
| NULL should be treated as a special category |
| verbose -- Boolean, whether to print more info, default is False |
| ); |
| |
| ------------------------------------------------------------ |
| OUTPUT |
| ------------------------------------------------------------ |
| The output table ('output_table' above) has the following columns (quoted items |
| are of type TEXT): |
| <grouping columns> -- Grouping columns, only present when |
| 'grouping_cols' is not NULL or '' |
| tree -- The decision tree model as a binary string |
| cat_levels_in_text -- Distinct levels (casted to text) of all |
| categorical variables combined in a single array |
| cat_n_levels -- Number of distinct levels of all categorical variables |
| tree_depth -- Number of levels in the tree (root has level 0) |
| pruning_cp -- The cost-complexity parameter used for pruning |
| the trained tree(s). This would be different |
| from the input cp value if cross-validation is used. |
| |
| The output summary table ('output_table_summary') has the following columns: |
| 'method' -- Method name: 'tree_train' |
| 'source_table' -- Data table name |
| 'model_table' -- Tree model table name |
| 'id_col_name' -- Name of the 'id' column |
| is_classification -- Boolean value indicating if tree is classification or regression |
| 'dependent_varname' -- Response variable column name |
| 'independent_varnames' -- Comma-separated feature column names |
| 'cat_features' -- Comma-separated column names of categorical variables |
| 'con_features' -- Comma-separated column names of continuous variables |
| 'grouping_cols' -- Grouping column names |
| num_all_groups -- Number of groups |
| num_failed_groups -- Number of groups for which training failed |
| total_rows_processed -- Number of rows used in the model training |
        total_rows_skipped      -- Number of rows skipped because of NULL values
| dependent_var_levels -- For classification, the distinct levels of |
| the dependent variable |
| dependent_var_type -- The type of dependent variable |
| input_cp -- The complexity parameter (cp) used for pruning the |
| trained tree(s) (before cross-validation is run) |
| independent_var_types -- The types of independent variables, comma-separated |
        n_folds                 -- Number of cross-validation folds (0 if
                                   cross-validation was not used)
| null_proxy -- String used as replacement for NULL values |
| (NULL if null_as_category = False) |
| |
| """ |
| else: |
| help_string = "No such option. Use {schema_madlib}.tree_train('usage')" |
| return help_string.format(schema_madlib=schema_madlib) |
| # ------------------------------------------------------------ |
| |
| |
| def tree_predict_help_message(schema_madlib, message, **kwargs): |
| """ Help message for Decision Tree predict |
| """ |
| if not message: |
| help_string = """ |
| ------------------------------------------------------------ |
| SUMMARY |
| ------------------------------------------------------------ |
| Functionality: Decision Tree Prediction |
| |
| Prediction for a decision tree (trained using {schema_madlib}.tree_train) can |
| be performed on a new data table. |
| |
| For more details on the function usage: |
| SELECT {schema_madlib}.tree_predict('usage'); |
| For an example on using this function: |
| SELECT {schema_madlib}.tree_predict('example'); |
| """ |
| elif message.lower().strip() in ['usage', 'help', '?']: |
| help_string = """ |
| ------------------------------------------------------------ |
| USAGE |
| ------------------------------------------------------------ |
| SELECT {schema_madlib}.tree_predict( |
| 'tree_model', -- Model table name (output of tree_train) |
| 'new_data_table', -- Prediction source table |
| 'output_table', -- Table name to store the predictions |
| 'type' -- Type of prediction output |
| ); |
| |
| Note: The 'new_data_table' should have the same 'id_col_name' column as used |
    in the training function. This is used to correlate the prediction data row with
| the actual prediction in the output table. |
| |
| ------------------------------------------------------------ |
| OUTPUT |
| ------------------------------------------------------------ |
| The output table ('output_table' above) has the '<id_col_name>' column giving |
| the 'id' for each prediction and the prediction columns for the response |
    variable (also called the dependent variable).
| |
| If prediction type = 'response', then the table has a single column with the |
| prediction value of the response. The type of this column depends on the type |
| of the response variable used during training. |
| |
| If prediction type = 'prob', then the table has multiple columns, one for each |
| possible value of the response variable. The columns are labeled as |
    'estimated_prob_<dep value>', where <dep value> represents each distinct
    value of the response.
| """ |
| else: |
| help_string = "No such option. Use {schema_madlib}.tree_predict('usage')" |
| return help_string.format(schema_madlib=schema_madlib) |
| # ------------------------------------------------------------ |