# coding=utf-8
"""
@file decision_tree.py_in
@brief Decision Tree: Driver functions
@namespace decision_tree
"""
from __future__ import division
import plpy
from math import sqrt
from operator import itemgetter
from itertools import groupby
from collections import Iterable
from internal.db_utils import quote_literal
from utilities.control import MinWarning
from utilities.control import OptimizerControl
from utilities.control import HashaggControl
from utilities.utilities import _assert
from utilities.utilities import _array_to_string
from utilities.utilities import _check_groups
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import add_postfix
from utilities.utilities import is_psql_numeric_type, is_psql_boolean_type
from utilities.utilities import py_list_to_sql_string
from utilities.utilities import split_quoted_delimited_str
from utilities.validate_args import _get_table_schema_names
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import explicit_bool_to_text
from utilities.validate_args import get_cols
from utilities.validate_args import get_cols_and_types
from utilities.validate_args import get_expr_type
from utilities.validate_args import is_var_valid
from utilities.validate_args import table_is_empty
from utilities.validate_args import table_exists
from utilities.validate_args import unquote_ident
from validation.cross_validation import cross_validation_grouping_w_params
# ------------------------------------------------------------
def _tree_validate_args(
split_criterion, training_table_name, output_table_name,
id_col_name, list_of_features, dependent_variable,
list_of_features_to_exclude, grouping_cols, weights, max_depth,
min_split, min_bucket, n_bins, cp, n_folds, **kwargs):
""" Validate the arguments
"""
_assert(training_table_name and
training_table_name.strip().lower() not in ('null', ''),
"Decision tree error: Invalid data table.")
_assert(table_exists(training_table_name),
"Decision tree error: Data table is missing.")
_assert(not table_exists(output_table_name, only_first_schema=True),
"Decision tree error: Output table already exists.")
_assert(not table_exists(add_postfix(output_table_name, "_summary"), only_first_schema=True),
"Decision tree error: Output summary table already exists.")
_assert(columns_exist_in_table(training_table_name, [id_col_name]),
"Decision tree error: ID column does not exist.")
_assert(not (dependent_variable is None or dependent_variable.strip().lower() == ''),
"Decision tree error: Dependent variable is empty.")
_assert(is_var_valid(training_table_name, dependent_variable),
"Decision tree error: Invalid dependent variable ({0}).".
format(dependent_variable))
_assert(list_of_features and list_of_features.strip(),
"Decision tree error: Features to include is empty.")
if list_of_features.strip() != '*':
_assert(is_var_valid(training_table_name, list_of_features),
"Decision tree error: Invalid feature list ({0})".
format(list_of_features))
if grouping_cols:
_assert(is_var_valid(training_table_name, grouping_cols),
"Decision tree error: Invalid grouping column argument.")
if weights is not None and weights.strip() != '':
_assert(is_var_valid(training_table_name, weights),
"Decision tree error: Invalid weights argument.")
_assert(max_depth >= 0 and max_depth < 100,
"Decision tree error: maximum tree depth must be non-negative and less than 100.")
_assert(not table_is_empty(training_table_name,
_get_filter_str(dependent_variable, grouping_cols)),
"Decision tree error: Data table ({0}) is empty "
"(after filtering invalid tuples)".
format(training_table_name))
_assert(cp >= 0, "Decision tree error: cp must be non-negative.")
_assert(min_split > 0, "Decision tree error: min_split must be positive.")
_assert(min_bucket > 0, "Decision tree error: min_bucket must be positive.")
_assert(n_bins > 1, "Decision tree error: number of bins must be at least 2.")
_assert(n_folds >= 0, "Decision tree error: number of cross-validation "
"folds must be non-negative.")
# ------------------------------------------------------------
def _validate_split_criterion(split_criterion, is_classification):
_assert(split_criterion.lower().strip() in ['mse', 'gini', 'cross-entropy',
'entropy', 'misclass',
'misclassification'],
"Decision tree error: Invalid split_criterion.")
if is_classification:
if split_criterion.lower().strip() == "mse":
plpy.error("Decision tree error: MSE is not a valid "
"split criterion for classification.")
else:
if split_criterion.lower().strip() != "mse":
plpy.warning("Decision tree: Using MSE as split criterion as it "
"is the only one supported for regression trees.")
split_criterion = "mse"
return split_criterion
# ------------------------------------------------------------------------------
def _get_features_to_use(schema_madlib, training_table_name,
list_of_features, list_of_features_to_exclude,
id_col_name, weights, dependent_variable,
grouping_cols=None, **kwargs):
""" Expand '*' syntax and exclude some features
Ignore 'list_of_features_to_exclude' if 'list_of_features' is not '*'
We also exclude from the features all grouping_cols, id_col_name, weights and dependent_variable
"""
# for some of the sets below we include the quoted name and the unquoted name
# to allow user to provide either form. Both forms are added in the exclusion
# list.
if grouping_cols:
group_set = set(split_quoted_delimited_str(grouping_cols))
group_set |= set(unquote_ident(i)
for i in split_quoted_delimited_str(grouping_cols))
else:
group_set = set()
other_col_set = set([id_col_name, weights, dependent_variable])
other_col_set |= set(unquote_ident(i)
for i in [id_col_name, weights, dependent_variable])
if list_of_features.strip() == '*':
all_col_set = set(get_cols(training_table_name))
exclude_set = set(split_quoted_delimited_str(list_of_features_to_exclude))
feature_set = all_col_set - exclude_set
filtered_feature_list = list(feature_set - group_set - other_col_set)
else:
feature_list = split_quoted_delimited_str(list_of_features)
feature_exclude = split_quoted_delimited_str(list_of_features_to_exclude)
return_set = set(feature_list) - set(feature_exclude) - group_set - other_col_set
# instead of returning list(return_set) we create a list that has
# elements in same order as original feature_list
filtered_feature_list = [feat for feat in feature_list if feat in return_set]
# check if any of the features is an array and expand the array
final_feature_list = []
for feat in filtered_feature_list:
feat_type = get_expr_type(feat, training_table_name)
if '[]' in feat_type:
# expand array by indexing into it
feat_dims = plpy.execute("""
SELECT array_lower({f}, 1) as l,
array_upper({f}, 1) as u
FROM {tbl}
LIMIT 1
""".format(f=feat, tbl=training_table_name))[0]
final_feature_list += ["({f})[{i}]".format(f=feat, i=i)
for i in range(feat_dims['l'], feat_dims['u'] + 1)]
else:
final_feature_list.append(feat)
return final_feature_list
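# Illustrative sketch (hypothetical table/column names): for a table with
# columns (id, grp, wt, y, x1, x2, readings double precision[]), calling this
# with list_of_features='*', list_of_features_to_exclude='x2', id_col_name='id',
# weights='wt', dependent_variable='y' and grouping_cols='grp' keeps x1 and
# readings, and then expands the array column using its bounds into
#   ['(readings)[1]', '(readings)[2]', '(readings)[3]']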
# ------------------------------------------------------------
def _classify_features(feature_to_type, features):
""" Returns
1) an array of categorical features
2) an array of ordered categorical features
3) an array of boolean categorical variables
4) an array of continuous features
"""
# any column belonging to the following types are categorical
int_types = ['integer', 'smallint', 'bigint']
text_types = ['text', 'varchar', 'character varying', 'char', 'character']
boolean_types = ['boolean']
cat_types = int_types + text_types + boolean_types
ordered_cat_types = int_types
cat_features = [c for c in features if feature_to_type[c] in cat_types]
ordered_cat_features = [c for c in features
if feature_to_type[c] in ordered_cat_types]
# In order to be able to form an array, all categorical variables
# will be cast into TEXT type, but GPDB cannot cast a boolean
# directly into a text. Thus, boolean categorical variables
# need special treatment: cast them into integers before casting
# into text.
boolean_cat_features = [c for c in features
if feature_to_type[c] in boolean_types]
# Integer types are not considered continuous
con_features = [c for c in features
if is_psql_numeric_type(feature_to_type[c], exclude=int_types)]
return cat_features, ordered_cat_features, boolean_cat_features, con_features
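# Illustrative sketch (hypothetical feature names/types): with
#   feature_to_type = {'color': 'text', 'cyl': 'integer',
#                      'vs': 'boolean', 'disp': 'double precision'}
# and features = ['color', 'cyl', 'vs', 'disp'], this returns
#   cat_features         = ['color', 'cyl', 'vs']
#   ordered_cat_features = ['cyl']
#   boolean_cat_features = ['vs']
#   con_features         = ['disp']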
# ------------------------------------------------------------
def _extract_pruning_params(pruning_params_str):
"""
Args:
@param pruning_param: str, Parameters used for pruning the tree
cp = Cost-complexity for pruning
n_folds = Number of folds for cross-validation
Returns:
dict. A dictionary containing the pruning parameters
"""
default_dict = dict(cp=0, n_folds=0)
params_types = dict(cp=float, n_folds=int)
pruning_params = extract_keyvalue_params(pruning_params_str,
params_types,
default_dict)
if pruning_params['n_folds'] < 0:
plpy.error("Decision Tree error: Number of cross-validation folds ({0}) "
"must be non-negative".format(pruning_params['n_folds']))
return pruning_params
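# Illustrative sketch (hypothetical values): a pruning_params string such as
# 'cp=0.01, n_folds=5' is parsed into {'cp': 0.01, 'n_folds': 5}, while an
# empty string falls back to the defaults {'cp': 0, 'n_folds': 0}.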
# ------------------------------------------------------------
def _get_tree_states(schema_madlib, is_classification, split_criterion,
training_table_name, output_table_name, id_col_name,
dependent_variable, dep_is_bool,
grouping_cols, cat_features, ordered_cat_features,
con_features, n_bins, boolean_cats,
min_split, min_bucket, weights,
max_depth, grp_key_to_cp, compute_cp_list=False,
max_n_surr=0, null_proxy=None, **kwargs):
"""
Args:
grp_key_to_cp : Dictionary, mapping from group key to the cp value
for each group. For the no grouping case, the
key is ''
"""
filter_dep = _get_filter_str(dependent_variable, grouping_cols)
if is_classification:
# For classifications, we also need to map dependent_variable to integers
n_rows, dep_list = _get_n_and_deplist(training_table_name,
dependent_variable,
filter_dep)
if dep_is_bool:
# false = 0, true = 1
# This order is maintained in dep_list since
# _get_n_and_deplist returns a sorted list
dep_var_str = ("(CASE WHEN {0} THEN 1 ELSE 0 END)".
format(dependent_variable))
else:
dep_var_str = ("(CASE " +
"\n\t\t".join(["WHEN ({0})::text = $${1}$$ THEN {2}".
format(dependent_variable, str(c), i)
for i, c in enumerate(dep_list)]) +
"\nEND)")
else:
n_rows = long(plpy.execute("SELECT count(*)::bigint "
"FROM {src} "
"WHERE {filter}".
format(src=training_table_name,
filter=filter_dep))[0]['count'])
dep_var_str = dependent_variable
dep_list = []
dep_n_levels = len(dep_list) if dep_list else 1
cat_features_info_table = unique_string()
if not grouping_cols: # non-grouping case
# 3) Find the splitting bins, one dict containing two arrays:
# categorical bins and continuous bins
bins = _get_bins(schema_madlib, training_table_name, cat_features,
ordered_cat_features, con_features, n_bins,
dep_var_str, boolean_cats, n_rows, is_classification,
dep_n_levels, filter_dep, null_proxy)
# some features may be dropped if they have only one value
cat_features = bins['cat_features']
if not cat_features and not con_features:
plpy.error("Decision tree: None of the input features are valid")
_create_cat_features_info_table(cat_features_info_table, bins)
# 4) Run tree train till the training is finished
# finished: 0 = running, 1 = finished training, 2 = terminated prematurely
tree = _tree_train_using_bins(**locals())
tree['grp_key'] = ''
tree['cp'] = grp_key_to_cp[tree['grp_key']]
tree_states = [tree]
else:
grouping_array_str = get_grouping_array_str(training_table_name, grouping_cols)
with OptimizerControl(False):
# we disable optimizer (ORCA) for platforms that use it
# since ORCA doesn't provide an easy way to disable hashagg
with HashaggControl(False):
# we disable hashagg since large number of groups could
# result in excessive memory usage.
plpy.notice("Analyzing data to compute split boundaries for variables")
bins = _get_bins_grps(schema_madlib, training_table_name,
cat_features, ordered_cat_features,
con_features, n_bins,
dep_var_str,
boolean_cats, grouping_cols,
grouping_array_str, n_rows,
is_classification, dep_n_levels,
filter_dep, null_proxy)
cat_features = bins['cat_features']
if not cat_features and not con_features:
plpy.error("Decision tree: None of the input features "
"are valid for some groups")
_create_cat_features_info_table(cat_features_info_table, bins)
# 3b) Load each group's tree state in memory and set to the initial tree
tree_states = _tree_train_grps_using_bins(**locals())
for tree in tree_states:
grp_key = tree['grp_key']
if len(grp_key_to_cp.values()) == 1:
# for training without CV, the cp value is the same for all
# groups and is passed in as a single-entry dictionary.
tree['cp'] = grp_key_to_cp.values()[0]
else:
tree['cp'] = grp_key_to_cp[grp_key]
# 5) prune the tree using provided 'cp' value and produce a list of
# cp values if cross-validation is required (cp_list = [] if not)
for tree in tree_states:
if 'cp' in tree:
pruned_tree = _prune_and_cplist(schema_madlib, tree,
tree['cp'],
compute_cp_list=compute_cp_list)
tree['tree_state'] = pruned_tree['tree_state']
if 'pruned_depth' in pruned_tree:
tree['pruned_depth'] = pruned_tree['pruned_depth']
else:
tree['pruned_depth'] = pruned_tree['tree_depth']
if 'cp_list' in pruned_tree:
tree['cp_list'] = pruned_tree['cp_list']
importance_vectors = _compute_var_importance(
schema_madlib, tree,
len(cat_features), len(con_features))
tree.update(**importance_vectors)
return tree_states, bins, dep_list, n_rows, cat_features_info_table
# -------------------------------------------------------------------------
def get_grouping_array_str(table_name, grouping_cols, qualifier=None):
"""
Args:
@param grouping_cols: list, List of columns used as grouping columns
"""
def _col_to_text(col, col_type):
qualifier_str = qualifier + "." if qualifier else ''
if is_psql_boolean_type(col_type):
return "(case when {0} then 'true' else 'false' end)::text".format(col)
else:
return '({0}{1})::text'.format(qualifier_str, col)
grouping_cols_list = [col.strip() for col in grouping_cols.split(',')]
grouping_cols_and_types = [(c, get_expr_type(c, table_name))
for c in grouping_cols_list]
grouping_array_str = "array_to_string(array[{0}]::text[], ',')".format(
','.join(_col_to_text(col, col_type) for col, col_type in grouping_cols_and_types))
return grouping_array_str
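# Illustrative sketch (hypothetical grouping columns 'region, is_member',
# where is_member is boolean and no qualifier is given): the generated
# expression is roughly
#   array_to_string(array[(region)::text,
#       (case when is_member then 'true' else 'false' end)::text]::text[], ',')
# i.e. each row's grouping values concatenated into one comma-separated key.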
# ------------------------------------------------------------------------------
# XXX This function is used only in decision tree cross-validation.
# So it has some specific code for cv.
def _build_tree(schema_madlib, is_classification, split_criterion,
training_table_name, output_table_name, id_col_name,
dependent_variable, dep_is_bool, list_of_features,
cat_features, ordered_cat_features,
boolean_cats, con_features, grouping_cols,
weights, max_depth, min_split, min_bucket, n_bins,
cp_table, max_n_surr=0, null_proxy=None,
msg_level="warning", n_folds=0, **kwargs):
compute_cp_list = False
if grouping_cols:
grouping_array_str = get_grouping_array_str(cp_table, grouping_cols)
else:
grouping_array_str = "''::TEXT"
sql = """SELECT
{0} AS grp_key,
explore_value as cp_val
FROM {1}
""".format(grouping_array_str, cp_table)
grp_cp_values = plpy.execute(sql)
grp_key_to_cp = dict((row['grp_key'], row['cp_val']) for row in grp_cp_values)
with MinWarning(msg_level):
plpy.notice("Building tree for cross validation")
tree_states, bins, dep_list, n_rows, cat_features_info_table = \
_get_tree_states(**locals())
all_cols_types = dict([(f, get_expr_type(f, training_table_name))
for f in cat_features + con_features])
n_all_rows = plpy.execute("select count(*) from " + training_table_name
)[0]['count']
cp = grp_key_to_cp.values()[0]
# _create_output_tables(...) is a general function;
# we need to let it know that it is currently being called from CV.
use_existing_tables = table_exists(output_table_name)  # create the table only if it does not already exist
running_cv = True # flag to indicate that cv fold ID needs to be included
_create_output_tables(list_of_features_to_exclude='',
**locals())
# ------------------------------------------------------------------------------
def tree_train(schema_madlib, training_table_name, output_table_name,
id_col_name, dependent_variable, list_of_features,
list_of_features_to_exclude, split_criterion,
grouping_cols, weights, max_depth,
min_split, min_bucket, n_bins, pruning_params,
null_handling_params, verbose_mode, **kwargs):
""" Decision tree main training function
Args:
@param schema_madlib: str, MADlib schema name
@param training_table_name: str, source table name
@param output_table_name: str, model table name
@param id_col_name: str, id column name to uniquely identify each row
@param dependent_variable: str, dependent variable column name
@param list_of_features: str, Comma-separated list of feature column names,
can also be '*' implying all columns
except dependent_variable
@param list_of_features_to_exclude: str, features to exclude if '*' is used
@param split_criterion: str, Impurity function to use for splitting:
Classification: {'gini', 'entropy', 'misclass'}
Regression: 'mse'
@param grouping_cols: str, List of grouping columns to group the data
@param weights: str, Column name for weight for each tuple
@param max_depth: int, Maximum depth of the tree
@param min_split: int, Minimum tuples in a node before splitting it
@param min_bucket: int, Minimum tuples in each child before splitting a node
@param n_bins: int, Number of bins for quantizing continuous variables
@param pruning_params: str, Parameters used for pruning the tree
cp = Cost-complexity for pruning
n_folds = Number of folds for cross-validation
@param null_handling_params: str, Parameters for NULL handling
(max_surrogates, null_as_category)
@param verbose_mode: bool, Verbosity of output messages
"""
msg_level = "notice" if verbose_mode else "warning"
# Set default values for all arguments
split_criterion = 'gini' if not split_criterion else split_criterion
max_depth = 7 if max_depth is None else max_depth
if min_split is None and min_bucket is None:
min_split = 20
min_bucket = 6
else:
min_bucket = min_split // 3 if min_bucket is None else min_bucket
min_split = min_bucket * 3 if min_split is None else min_split
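# Illustrative defaults (hypothetical inputs): with neither value supplied,
# min_split=20 and min_bucket=6; supplying only min_split=30 implies
# min_bucket=10, and supplying only min_bucket=5 implies min_split=15.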
n_bins = 20 if n_bins is None else n_bins
# defaults for cp and n_folds set within _extract_pruning_params
pruning_param_dict = _extract_pruning_params(pruning_params)
cp = pruning_param_dict['cp']
n_folds = pruning_param_dict['n_folds']
# null handing parameters: max_n_surr and null_as_category
null_handling_dict = extract_keyvalue_params(
null_handling_params,
dict(max_surrogates=int, null_as_category=bool),
dict(max_surrogates=0, null_as_category=False))
max_n_surr = null_handling_dict['max_surrogates']
null_as_category = null_handling_dict['null_as_category']
null_proxy = "__NULL__" if null_as_category else None
if null_as_category:
# can't have two ways of handling tuples with NULL values
max_n_surr = 0
_assert(max_n_surr >= 0,
"Maximum number of surrogates ({0}) should be non-negative".
format(max_n_surr))
with MinWarning(msg_level):
# 1) Validate input arguments
if not grouping_cols or not grouping_cols.strip():
grouping_cols = ''
_tree_validate_args(**locals())
# the weights column has to be validated as present,
# hence the default value is assigned after the validation
weights = '1' if not weights or not weights.strip() else weights.strip()
# expand "*" syntax and exclude some features
features = _get_features_to_use(**locals())
_assert(bool(features),
"Decision tree error: No feature is selected for the model.")
# 2) Classify features as categorical or continuous
all_cols_types = dict([(f, get_expr_type(f, training_table_name))
for f in features])
cat_features, ordered_cat_features, boolean_cats, con_features = \
_classify_features(all_cols_types, features)
# assert that the continuous and categorical features together
# cover all features
invalid_features = set(features) - (set(cat_features) | set(con_features))
_assert(not invalid_features,
"DT error: Some of the features are invalid ({0})".
format(invalid_features))
# get all rows
n_all_rows = plpy.execute("SELECT count(*) FROM {source_table}".
format(source_table=training_table_name)
)[0]['count']
is_classification, dep_is_bool = _is_dep_categorical(
training_table_name, dependent_variable)
split_criterion = _validate_split_criterion(split_criterion, is_classification)
# 4) Build the tree with provided cp value
compute_cp_list = (n_folds > 1)
grp_key_to_cp = {'': cp}
# main training function to get trained decision trees
plpy.notice("Getting initial tree")
tree_states, bins, dep_list, n_rows, cat_features_info_table = \
_get_tree_states(**locals())
# 5) Perform cross-validation to compute the lowest cp
dep_n_levels = len(dep_list) if dep_list else 1
if n_folds > 1:
plpy.notice("Running cross validation")
_xvalidate(**locals())
plpy.notice("Creating output tables")
_create_output_tables(**locals())
return None
# ------------------------------------------------------------
def _create_output_tables(schema_madlib, training_table_name, output_table_name,
tree_states, bins, split_criterion,
id_col_name, dependent_variable, list_of_features,
list_of_features_to_exclude,
is_classification, n_all_rows, n_rows, dep_list, cp,
all_cols_types, cat_features_info_table,
grouping_cols=None,
use_existing_tables=False, running_cv=False,
n_folds=0, null_proxy=None, **kwargs):
if not grouping_cols:
_create_result_table(schema_madlib, tree_states[0],
bins['cat_origin'], bins['cat_n'], bins['cat_features'],
bins['con_features'], output_table_name,
use_existing_tables, running_cv, n_folds)
else:
_create_grp_result_table(
schema_madlib, tree_states, bins, bins['cat_features'],
bins['con_features'], output_table_name, cat_features_info_table,
grouping_cols, training_table_name, use_existing_tables,
running_cv, n_folds)
failed_groups = sum(row['finished'] != 1 for row in tree_states)
_create_summary_table(
schema_madlib, split_criterion, training_table_name,
output_table_name, id_col_name, bins['cat_features'], bins['con_features'],
dependent_variable, list_of_features, list_of_features_to_exclude,
failed_groups, is_classification, n_all_rows,
n_rows, dep_list, all_cols_types, cp, grouping_cols, 1,
use_existing_tables, n_folds, null_proxy)
# -------------------------------------------------------------------------
def _get_n_and_deplist(training_table_name, dependent_variable, filter_null):
"""
@brief Query the database for the total number of rows and
levels of dependent variable if the dependent variable is
categorical.
Note: The deplist is sorted in the array_agg which is necessary to ensure
false = 0 and true = 1 for boolean dependent variable.
"""
sql = """
SELECT
sum(n_rows) as n_rows,
array_agg(dep ORDER BY dep) as dep
FROM (
SELECT
count(*) as n_rows,
{dep_var} as dep
FROM
{source}
WHERE {filter_null}
GROUP BY ({dep_var})
) s
""".format(dep_var=dependent_variable,
source=training_table_name,
filter_null=filter_null)
r = plpy.execute(sql)[0]
r['n_rows'] = 0 if r['n_rows'] is None else r['n_rows']
return long(r['n_rows']), r['dep']
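# Illustrative sketch (hypothetical boolean dependent variable): the query
# returns the total number of valid rows and the sorted distinct levels, e.g.
#   (1000, [False, True])
# so that False maps to 0 and True maps to 1 downstream.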
# ------------------------------------------------------------
def _is_dep_categorical(training_table_name, dependent_variable):
"""
@brief Sample the dependent variable to check whether it is
a categorical variable.
"""
sample_dep = get_expr_type(dependent_variable, training_table_name)
is_dep_numeric = is_psql_numeric_type(sample_dep,
exclude=['smallint',
'integer',
'bigint'])
is_dep_bool = is_psql_boolean_type(sample_dep)
return (not is_dep_numeric, is_dep_bool)
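# Illustrative sketch: a 'double precision' dependent variable yields
# (False, False) (regression), an 'integer' or 'text' one yields
# (True, False) (classification), and a 'boolean' one yields (True, True).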
# ------------------------------------------------------------
def _get_bins(schema_madlib, training_table_name,
cat_features, ordered_cat_features,
con_features, n_bins, dependent_variable, boolean_cats, n_rows,
is_classification, dep_n_levels, filter_null_dep,
null_proxy=None):
""" Compute the bins of all features
@param training_table_name Data source table
@param cat_features A list of strings, categorical column names
@param con_features A list of strings, continuous column names
@param n_bins Number of bins; the number of splits equals n_bins - 1
@param dependent_variable Will be needed when sorting the levels of
categorical variables
@param boolean_cats The categorical variables that are of boolean type
returns one dictionary containing two kinds of bins: categorical and continuous
"""
if len(con_features) > 0:
if n_bins > n_rows:
plpy.error("Decision tree error: Number of bins is larger than "
"the number of data records.")
sample_size = n_bins * n_bins # We use Spark's value here
# FIXME Hard coded number
if sample_size < 10000:
sample_size = 10000
# Compute the percentage of the sample.
# We sample a few more values to make sure we can get enough
# samples, otherwise the number of samples might be smaller
# than sample_size.
# Use design doc Eq. (2.1)
actual_sample_size = sample_size + 14 + sqrt(196 + 28 * sample_size)
if actual_sample_size > n_rows:
actual_sample_size = n_rows
percentage = actual_sample_size / n_rows
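# Rough worked example (illustrative numbers): with n_bins = 20,
# sample_size = 400 is raised to the 10000 floor and Eq. (2.1) gives
#   actual_sample_size = 10000 + 14 + sqrt(196 + 28 * 10000) ~= 10543
# so with n_rows = 1,000,000 roughly 1.05% of the rows are sampled.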
# For continuous variables, use one function to compute
# the splits for all of them. Similar to the existing
# _compute_splits function in CoxPH module, but deal with
# multiple columns together.
con_features_str = py_list_to_sql_string(con_features, "double precision")
sample_table_name = unique_string()
plpy.execute("""
CREATE TEMP TABLE {sample_table_name} AS
SELECT *
FROM (
SELECT *, random() AS rand
FROM {training_table_name}
WHERE {filter_null_dep}
AND not {schema_madlib}.array_contains_null({con_features_str})
) subq
WHERE rand <= {percentage}
""".format(**locals()))
# The splits for continuous variables
con_split_str = ("""{schema_madlib}._dst_compute_con_splits(
{con_features_str},
{actual_sample_size}::integer,
{n_bins}::smallint)""".
format(**locals()))
con_splits = plpy.execute("""
SELECT {con_split_str} as con_splits
FROM {sample_table_name}
""".format(**locals()))[0]
plpy.execute("DROP TABLE {sample_table_name}".format(**locals()))
else:
con_splits = {'con_splits': ''} # no continuous features present
# For categorical variables, scan the whole table to extract all the
# levels of the categorical variables, and at the same time
# sort the levels according to the entropy of the dependent
# variable.
# So this aggregate returns a composite type with two columns:
# col 1 is the array of ordered levels; col 2 is the number of
# levels in col 1.
# TODO When n_bins is larger than 2^k - 1, where k is the number
# of levels of a given categorical feature, we can actually compute
# all combinations of levels and obtain a complete set of splits
# instead of using sorting to get an approximate set of splits.
#
# We will use integer to represent levels of categorical variables.
# So before everything, we need to create a mapping from categorical
# variable levels to integers, and keep this mapping in the memory.
if len(cat_features) > 0:
if is_classification:
# For classification the dependent variable is encoded as an integer
order_fun = ("{madlib}._dst_compute_entropy({dep}, {n})".
format(madlib=schema_madlib,
dep=dependent_variable,
n=dep_n_levels))
else:
# For regressions
order_fun = "AVG({0})".format(dependent_variable)
# Note that 'sql_cat_levels' goes through two levels of string formatting
# Try to obtain all the levels in one scan of the table.
# () are needed when casting the categorical variables because
# they can be expressions.
filter_str = filter_null_dep + " AND {col} IS NOT NULL"
if null_proxy is None:
union_null_proxy = ""
else:
union_null_proxy = "UNION " + """
SELECT '{null_proxy}'::text as levels,
'Infinity'::double precision as dep_avg
FROM {training_table_name}
WHERE {{col}} IS NULL
GROUP BY {{col}}
""".format(**locals())
sql_cat_levels = """
SELECT
'{{col_name}}'::text AS colname,
levels
FROM (
SELECT
'{{col_name}}'::text AS colname,
array_agg(levels ORDER BY dep_avg) AS levels
FROM (
SELECT
({{col}})::text AS levels,
{{order_fun}} AS dep_avg
FROM {training_table_name}
WHERE {filter_str}
GROUP BY {{col}}
{union_null_proxy}
) s
) s1
""".format(training_table_name=training_table_name,
filter_str=filter_str,
union_null_proxy=union_null_proxy)
all_col_expressions = dict(zip(cat_features, explicit_bool_to_text(
training_table_name, cat_features, schema_madlib)))
sql_all_cats = ' UNION '.join(
sql_cat_levels.format(
col=expr,
col_name=col_name,
order_fun=expr if col_name in ordered_cat_features else order_fun)
for col_name, expr in all_col_expressions.items())
all_levels = plpy.execute(sql_all_cats)
col_to_row = dict((row['colname'], i) for i, row in enumerate(all_levels))
return dict(
con=con_splits['con_splits'],
con_features=con_features,
cat_origin=[level for col in cat_features
for level in all_levels[col_to_row[col]]['levels']],
cat_n=[len(all_levels[col_to_row[col]]['levels'])
for col in cat_features],
cat_features=cat_features)
else:
# categorical part is empty
return dict(
con=con_splits['con_splits'],
con_features=con_features,
cat_origin=[],
cat_n=[],
cat_features=[])
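# Illustrative sketch of the returned bins dictionary (hypothetical features:
# 'cyl' categorical, 'disp' continuous):
#   {'con': <con_splits composite>, 'con_features': ['disp'],
#    'cat_origin': ['4', '6', '8'], 'cat_n': [3], 'cat_features': ['cyl']}
# cat_origin flattens the ordered levels of all categorical features and
# cat_n records how many levels belong to each of them.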
# ------------------------------------------------------------
def _create_result_table(schema_madlib, tree_state,
cat_origin, cat_n, cat_features,
con_features, output_table_name,
use_existing_tables=False, running_cv=False,
k=0):
""" Create the output table and the summary table
In the result table, we need the tree_state and also the categorical
sorted levels, which will be used in the printing and the prediction
functions.
"""
fold = ", " + str(k) + " as k" if running_cv else ""
if use_existing_tables:
# plpy.execute("truncate " + output_table_name)
header = "insert into " + output_table_name + " "
else:
header = "create table " + output_table_name + " as "
depth = (tree_state['pruned_depth'] if 'pruned_depth' in tree_state
else tree_state['tree_depth'])
if len(cat_features) > 0:
sql = header + """
SELECT
{cp} AS pruning_cp,
$1 AS tree,
$2 AS cat_levels_in_text,
$3 AS cat_n_levels,
$4 AS impurity_var_importance,
{depth} AS tree_depth
{fold}
""".format(depth=depth,
cp=tree_state['cp'],
fold=fold)
sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib),
'text[]',
'integer[]',
'double precision[]'])
plpy.execute(sql_plan, [tree_state['tree_state'],
cat_origin,
cat_n,
tree_state['impurity_var_importance']])
else:
sql = header + """
SELECT
{cp} AS pruning_cp,
$1 AS tree,
NULL::text[] AS cat_levels_in_text,
NULL::integer[] AS cat_n_levels,
$2 AS impurity_var_importance,
{depth} AS tree_depth
{fold}
""".format(depth=depth,
cp=tree_state['cp'],
fold=fold)
sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib),
'double precision[]'])
plpy.execute(sql_plan, [tree_state['tree_state'],
tree_state['impurity_var_importance']])
# ------------------------------------------------------------
def _get_bins_grps(
schema_madlib, training_table_name, cat_features, ordered_cat_features,
con_features, n_bins, dependent_variable, boolean_cats,
grouping_cols, grouping_array_str, n_rows, is_classification,
dep_n_levels, filter_null_dep, null_proxy=None):
""" Compute the bins for all features in each group
@brief Similar to _get_bins except that this is for multiple groups.
So please refer to the comments inside _get_bins for more related
information.
Returns the per-group bins dictionary that _one_step_for_grps needs.
"""
if len(con_features) > 0:
if n_bins > n_rows:
plpy.error("Decision tree error: Number of bins is larger than "
"the number of data records.")
sample_size = n_bins * n_bins # We use Spark's value here
# FIXME Hard coded number
if sample_size < 10000:
sample_size = 10000
# Compute the percentage of the sample.
# We sample a few more values to make sure we can get enough
# samples, otherwise the number of samples might be smaller
# than sample_size.
# Use design doc Eq. (2.1)
# FIXME Should also use this for the CoxPH module
sample_size = sample_size + 14 + sqrt(196 + 28 * sample_size)
grp_key_str = unique_string()
n_per_seg_str = unique_string()
grp_size_str = unique_string()
random_str = unique_string()
con_features_str = py_list_to_sql_string(con_features, "double precision")
sample_table_name = unique_string()
sql = """
CREATE TEMP TABLE {sample_table_name} AS
SELECT *
FROM
(
SELECT
*,
random() AS rand,
CASE WHEN {grp_size_str} < {sample_size}
THEN {grp_size_str}
ELSE {sample_size}
END AS {n_per_seg_str}
FROM
(
SELECT *
FROM {training_table_name}
WHERE {filter_null_dep}
AND not {schema_madlib}.array_contains_null({con_features_str})
) src
JOIN
(
SELECT {grouping_cols}, count(*) AS {grp_size_str}
FROM {training_table_name}
WHERE {filter_null_dep}
AND not {schema_madlib}.array_contains_null({con_features_str})
GROUP BY {grouping_cols}
) grp_info
USING ({grouping_cols})
) subq
WHERE rand * {grp_size_str} <= {sample_size}
""".format(**locals())
plpy.execute(sql)
# splits is a list, each of whose elements is a dictionary.
# The dictionary contains 2 items:
# 1) grp_key - It is an array of text
# 2) con_splits - continuous split array
con_split_str = """{schema_madlib}._dst_compute_con_splits(
{con_features_str},
{n_per_seg}::integer,
{n_bins}::smallint)""".format(con_features_str=con_features_str,
schema_madlib=schema_madlib,
n_per_seg=n_per_seg_str,
n_bins=n_bins)
con_splits_all = plpy.execute(
""" SELECT
{con_split_str} AS con_splits,
{grouping_array_str} AS grp_key
FROM {sample_table_name}
GROUP BY {grouping_cols}
""".format(**locals()) # multiple rows
)
plpy.execute("DROP TABLE {sample_table_name}".format(**locals()))
if cat_features:
if is_classification:
# For classifications
order_fun = ("{schema_madlib}._dst_compute_entropy("
"{dependent_variable}, {n})".
format(schema_madlib=schema_madlib,
dependent_variable=dependent_variable,
n=dep_n_levels))
else:
order_fun = "avg({0})".format(dependent_variable)
filter_str = filter_null_dep + " AND {col} IS NOT NULL"
if null_proxy is None:
union_null_proxy = ""
else:
union_null_proxy = "UNION " + """
SELECT '{null_proxy}'::text as levels,
'Infinity'::double precision as dep_avg
FROM {training_table_name}
WHERE {{col}} IS NULL
GROUP BY {{col}}
""".format(**locals())
sql_cat_levels = """
SELECT
colname::text,
levels::text[],
grp_key::text
from (
SELECT
grp_key,
'{{col_name}}' as colname,
array_agg(levels order by dep_avg) as levels
from (
SELECT
{grouping_array_str} as grp_key,
({{col}})::text as levels,
{{order_fun}} as dep_avg
FROM {training_table_name}
WHERE {filter_str}
GROUP BY {{col}}, {grouping_cols}
) s
GROUP BY grp_key
) s1
""".format(**locals())
all_col_expressions = dict(zip(cat_features, explicit_bool_to_text(
training_table_name, cat_features, schema_madlib)))
sql_all_cats = ' UNION ALL '.join(
sql_cat_levels.format(
col=expr,
col_name=col_name,
order_fun=expr if col_name in ordered_cat_features else order_fun)
for col_name, expr in all_col_expressions.items())
all_levels = list(plpy.execute(sql_all_cats))
all_levels.sort(key=itemgetter('grp_key'))
# grp_col_to_levels is a list of tuples (pairs) with
# first value = group value,
# second value = a dict mapping a categorical column to its levels in data
# (these levels are specific to the group and can be different
# for different groups)
# The list of tuples can be converted to a dict, but the ordering
# will be lost.
# eg. grp_to_col_to_levels =
# [
# ('3', {'vs': [0, 1], 'cyl': [4,6,8]}),
# ('4', {'vs': [0, 1], 'cyl': [4,6]}),
# ('5', {'vs': [0, 1], 'cyl': [4,6,8]})
# ]
grp_to_col_to_levels = [
(grp_key, dict((row['colname'], row['levels']) for row in items))
for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))]
grp_to_cat_features = dict([(g, col_to_levels.keys())
for (g, col_to_levels) in grp_to_col_to_levels])
# Below statements collect the grp_to_col_to_levels into multiple variables
# From above eg.
# cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]]
# cat_n = [2, 3, 2, 2, 2, 3]
# cat_origin = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8]
# grp_key_cat = ['3', '4', '5']
cat_items_list = [rows[col]
for grp_key, rows in grp_to_col_to_levels
for col in cat_features if col in rows]
cat_n = [len(i) for i in cat_items_list]
cat_origin = [item for sublist in cat_items_list for item in sublist]
grp_key_cat = [item[0] for item in grp_to_col_to_levels]
else:
cat_n = []
cat_origin = []
grp_key_cat = [con_splits['grp_key'] for con_splits in con_splits_all]
grp_to_col_to_levels = [(con_splits['grp_key'], dict())
for con_splits in con_splits_all]
grp_to_cat_features = dict([(con_splits['grp_key'], list())
for con_splits in con_splits_all])
if con_features:
con = [con_splits['con_splits'] for con_splits in con_splits_all]
grp_key_con = [con_splits['grp_key'] for con_splits in con_splits_all]
else:
grp_key_con = [grp_key[0] for grp_key in grp_to_col_to_levels]
con = [''] * len(grp_key_con)
return dict(con=con,
con_features=con_features,
grp_key_con=grp_key_con,
cat_origin=cat_origin,
cat_n=cat_n,
cat_features=cat_features,
grp_key_cat=grp_key_cat,
grouping_array_str=grouping_array_str,
grp_to_col_to_levels=grp_to_col_to_levels,
)
# ------------------------------------------------------------
def _create_cat_features_info_table(cat_features_info_table, bins):
# bins['grp_to_col_to_levels'] =
# [
# ('3', {'vs': [0, 1], 'cyl': [4,6,8]}),
# ('4', {'vs': [0, 1], 'cyl': [4,6]}),
# ('5', {'vs': [0, 1]})
# ]
# Convert this into a VALUES command and place in a table
# VALUES (('3', ARRAY[2, 3], ARRAY['0', '1', '4', '6', '8']),
# ('4', ARRAY[2, 2], ARRAY['0', '1', '4', '6']),
# ('5', ARRAY[2], ARRAY['0', '1']),
# )
cat_features_info_values = []
if 'grp_to_col_to_levels' in bins:
# Grouping enabled, implies the cat levels can be different for
# different groups
for i, (grp_key, col_to_levels) in enumerate(bins['grp_to_col_to_levels'], start=1):
grp_key_str = quote_literal(grp_key)
cat_names_levels = [(c, col_to_levels[c]) for c in bins['cat_features']
if c in col_to_levels]
if cat_names_levels:
cat_names, cat_levels = zip(*cat_names_levels)
# categorical features in current group (expressed in an array)
cat_names_str = py_list_to_sql_string(
map(quote_literal, cat_names), 'text', long_format=True)
# number of levels in each cat feature
cat_n_levels_str = py_list_to_sql_string(
map(len, cat_levels), 'integer', long_format=True)
# flatten the levels across all cat features
cat_levels = [quote_literal(each_level)
for sublist in cat_levels
for each_level in sublist]
cat_levels_str = py_list_to_sql_string(cat_levels,
'text',
long_format=True)
else:
# this is the case if no categorical features present
cat_names_str = cat_n_levels_str = cat_levels_str = "NULL"
cat_features_info_values.append(
"({i}::INTEGER, {grp_key_str}::TEXT, {cat_names_str}::TEXT[], "
"{cat_n_levels_str}::INTEGER[], {cat_levels_str}::TEXT[])".
format(**locals()))
else:
# no grouping
if bins['cat_features']:
cat_names_str = py_list_to_sql_string(
map(quote_literal, bins['cat_features']), 'text', long_format=True)
cat_n_levels_str = py_list_to_sql_string(bins['cat_n'], 'integer', long_format=True)
cat_levels_str = py_list_to_sql_string(
map(quote_literal, bins['cat_origin']), 'text', long_format=True)
else:
cat_names_str = cat_n_levels_str = cat_levels_str = "NULL"
cat_features_info_values.append(
"(1::INTEGER, ''::TEXT, {0}::TEXT[], {1}::INTEGER[], {2}::TEXT[])".
format(cat_names_str, cat_n_levels_str, cat_levels_str))
sql_cat_features_info = """
CREATE TEMP TABLE {0} AS
SELECT *
FROM (
VALUES {1}
) AS q(gid, grp_key, cat_names, cat_n_levels, cat_levels_in_text)
""".format(cat_features_info_table,
',\n'.join(cat_features_info_values))
plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info)
plpy.execute(sql_cat_features_info.format(**locals()))
# ------------------------------------------------------------------------------
def get_feature_str(schema_madlib, source_table,
cat_features, con_features,
levels_str, n_levels_str,
null_proxy=None):
if len(cat_features) > 0:
# null_val is the replacement for NULL in a categorical feature. If a
# null_proxy is set then the proxy is used to assign NULL as a valid
# category. If no proxy is available then NULL is replaced with a unique
# value. In a later step, the categorical levels are mapped to integers
# (1 to N). The unique value will be mapped to -1 indicating an
# unknown/missing value in the underlying layers.
null_val = unique_string() if null_proxy is None else null_proxy
# Cast boolean column to text: requires a special cast expression for
# platforms where __HAS_BOOL_TO_TEXT_CAST__ is not enabled
patched_cat_features = explicit_bool_to_text(source_table,
cat_features,
schema_madlib)
cat_features_cast = []
for col in patched_cat_features:
cat_features_cast.append(
"(coalesce(({0})::text, '{1}'))::text".format(col, null_val))
cat_features_str = ("{0}._map_catlevel_to_int(array[" +
", ".join(cat_features_cast) + "], {1}, {2}, {3})"
).format(schema_madlib,
levels_str,
n_levels_str,
null_proxy is not None)
else:
cat_features_str = "NULL"
if len(con_features) > 0:
con_features_list = ["COALESCE(" + str(con) + ", 'nan'::double precision)"
for con in con_features]
con_features_str = "ARRAY[" + ", ".join(con_features_list) + "]"
else:
con_features_str = "NULL"
return cat_features_str, con_features_str
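# Illustrative sketch (hypothetical columns 'color' categorical and 'disp'
# continuous, schema_madlib='madlib', no null_proxy, levels_str='$3',
# n_levels_str='$2'): the returned SQL fragments look roughly like
#   cat_features_str = "madlib._map_catlevel_to_int(array[
#       (coalesce((color)::text, '<unique_string>'))::text], $3, $2, False)"
#   con_features_str = "ARRAY[COALESCE(disp, 'nan'::double precision)]"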
# -------------------------------------------------------------------------
def _one_step(schema_madlib, training_table_name, cat_features,
con_features, boolean_cats, bins, n_bins, tree_state, weights,
dep_var, min_split, min_bucket, max_depth, filter_null,
dep_n_levels, subsample, n_random_features,
max_n_surr=0, null_proxy=None):
""" One step of tree training
@param tree_state A big double precision array that contains
(1) internal node: column and split value
(2) leaf node: the statistics described as a series of numbers
"""
# The function _map_catlevel_to_int maps a categorical variable value to its
# integer representation. It returns an integer array.
# XXX cat_features_str contains the $3 and $2 placeholders, and a SQL function
bytea8 = schema_madlib + '.bytea8'
cat_features_str, con_features_str = get_feature_str(schema_madlib,
training_table_name,
cat_features,
con_features,
"$3", "$2",
null_proxy)
train_sql = """
SELECT (result).* from (
SELECT
{schema_madlib}._dt_apply(
$1,
{schema_madlib}._compute_leaf_stats(
$1, -- current tree state, madlib.bytea8
{cat_features_str}, -- categorical features in an array
{con_features_str}, -- continuous features in an array
{dep_var},
{weights}, -- weight value
$2, -- categorical sorted levels in a combined array
$4, -- continuous splits
{dep_n_levels}::smallint, -- number of dependent levels
{subsample}::boolean -- should we use a subsample of data
),
$4,
{min_split}::smallint,
{min_bucket}::smallint,
{max_depth}::smallint,
{subsample}::boolean,
{n_random_features}::integer
) as result
FROM {training_table_name}
WHERE {filter_null}
) s
""".format(**locals())
train_sql_plan = plpy.prepare(train_sql, [bytea8, 'integer[]', 'text[]', bytea8])
# return a new tree state
updated_tree = plpy.execute(train_sql_plan, [tree_state['tree_state'],
bins['cat_n'],
bins['cat_origin'],
bins['con']
])[0]
# Compute surrogates:
# tree_depth outside the scope of dt_apply starts from 0, i.e. a depth of 0
# implies the tree is a single leaf node. We compute surrogates only for
# internal nodes, hence we need at least the root to be an internal node,
# i.e. the tree_depth must be 1 or more.
if max_n_surr > 0 and updated_tree['tree_depth'] > 0:
dup_count_expr = weights + "::integer" if subsample else '1'
surr_sql = """
SELECT
{schema_madlib}._dt_surr_apply(
$1,
{schema_madlib}._compute_surr_stats(
$1,
{cat_features_str},
{con_features_str},
$2,
$4,
{dup_count_expr}),
$4) as result
FROM {training_table_name}
WHERE {filter_null}
""".format(**locals())
surr_sql_plan = plpy.prepare(surr_sql, [bytea8, 'integer[]', 'text[]', bytea8])
# return a new tree state
updated_tree['tree_state'] = plpy.execute(surr_sql_plan,
[updated_tree['tree_state'],
bins['cat_n'],
bins['cat_origin'],
bins['con']
])[0]["result"]
return updated_tree
# ------------------------------------------------------------
def _one_step_for_grps(
schema_madlib, training_table_name, cat_features,
con_features, boolean_cats, bins, n_bins, tree_states, weights,
grouping_cols, grouping_array_str, dep_var, min_split, min_bucket,
max_depth, filter_null, dep_n_levels, subsample, n_random_features,
cat_features_info_table, max_n_surr=0, null_proxy=None):
""" One step of trees training with grouping support
"""
# The function _map_catlevel_to_int maps a categorical variable value to its
# integer representation. It returns an integer array.
# XXX cat_features_str contains the cat_levels_in_text and cat_n_levels aliases, and a SQL function
bytea8arr = schema_madlib + '.bytea8[]'
# avoid name conflicts
grp_key = unique_string()
tree_state = unique_string()
con_splits = unique_string()
cat_n_levels = unique_string()
finished = unique_string()
cat_levels_in_text = unique_string()
cat_features_str, con_features_str = get_feature_str(
schema_madlib, training_table_name, cat_features, con_features,
cat_levels_in_text, cat_n_levels, null_proxy)
train_apply_func = """
{schema_madlib}._dt_apply(
{tree_state},
agg_result,
{con_splits},
{min_split}::smallint,
{min_bucket}::smallint,
{max_depth}::smallint,
{subsample}::boolean,
{n_random_features}::integer)
""".format(**locals())
train_aggregate = """
{schema_madlib}._compute_leaf_stats(
{tree_state},
{cat_features_str},
{con_features_str},
{dep_var},
{weights},
{cat_n_levels},
{con_splits},
{dep_n_levels}::smallint,
{subsample}::boolean)
""".format(**locals())
sql = """
SELECT
s1.grp_key,
{apply_func} as result
FROM ( SELECT
{grouping_array_str} AS grp_key,
{aggregate} AS agg_result
FROM
{training_table_name} as src,
( SELECT
grp_key AS {grp_key},
finished AS {finished},
tree_state AS {tree_state},
con_splits AS {con_splits},
cat_n_levels::INTEGER[] AS {cat_n_levels},
cat_levels_in_text::TEXT[] AS {cat_levels_in_text}
FROM
( SELECT
unnest($1) AS grp_key,
unnest($2) AS finished,
unnest($3) AS tree_state
) AS tree_state_set
JOIN (
SELECT
unnest($4) AS grp_key,
unnest($5) AS con_splits
) AS con_splits
USING (grp_key)
JOIN
{cat_features_info_table}
USING (grp_key)
) AS needed_data
WHERE {grouping_array_str} = {grp_key}
AND {filter_null}
GROUP BY {grouping_cols}
) s1
JOIN ( SELECT
grp_key,
tree_state AS {tree_state},
con_splits AS {con_splits}
FROM ( SELECT
unnest($1) AS grp_key,
unnest($3) AS tree_state
) AS tree_state_set
JOIN
( SELECT
unnest($4) AS grp_key,
unnest($5) AS con_splits
) AS con_splits
USING (grp_key)
) s2
USING (grp_key)
"""
train_sql = "SELECT grp_key, (result).* FROM (" + sql + ") sub"
train_sql = train_sql.format(aggregate=train_aggregate,
apply_func=train_apply_func,
# check_finished="AND " + finished + " = 0",
**locals())
train_sql_plan = plpy.prepare(train_sql,
['text[]', 'integer[]', bytea8arr,
'text[]', bytea8arr])
unfinished_trees = [t for t in tree_states if t['finished'] == 0]
finished_trees = [t for t in tree_states if t['finished'] != 0]
# Need to cast the output of plpy.execute to 'list' since newer versions
# of Postgres return a PLyResult object instead of a list
updated_unfinished = list(plpy.execute(train_sql_plan, [
[t['grp_key'] for t in unfinished_trees],
[t['finished'] for t in unfinished_trees],
[t['tree_state'] for t in unfinished_trees],
bins['grp_key_con'],
bins['con']]))
if max_n_surr > 0:
surr_apply_func = """
{schema_madlib}._dt_surr_apply(
{tree_state},
agg_result,
{con_splits})
""".format(schema_madlib=schema_madlib,
tree_state=tree_state,
con_splits=con_splits)
dup_count_expr = weights + "::integer" if subsample else '1'
surr_aggregate = """
{schema_madlib}._compute_surr_stats(
{tree_state},
{cat_features_str},
{con_features_str},
{cat_n_levels},
{con_splits},
{dup_count_expr})
""".format(**locals())
# update the tree to compute surrogates for the layer above the just
# computed leaves
surr_sql = sql.format(aggregate=surr_aggregate,
apply_func=surr_apply_func,
check_finished="",
**locals())
surr_sql_plan = plpy.prepare(surr_sql,
['text[]', 'integer[]', bytea8arr,
'text[]', bytea8arr])
surr_trees = list(plpy.execute(surr_sql_plan, [
[t['grp_key'] for t in updated_unfinished],
[t['finished'] for t in updated_unfinished],
[t['tree_state'] for t in updated_unfinished],
bins['grp_key_con'],
bins['con']]))
surr_dict = dict()
for each_tree in surr_trees:
surr_dict[each_tree['grp_key']] = each_tree['result']
for each_tree in updated_unfinished:
if each_tree['grp_key'] in surr_dict:
each_tree['tree_state'] = surr_dict[each_tree['grp_key']]
# return all tree states (finished and unfinished)
updated_unfinished.extend(finished_trees)
return updated_unfinished
# ------------------------------------------------------------
def _create_grp_result_table(
schema_madlib, tree_states, bins, cat_features,
con_features, output_table_name, cat_features_info_table,
grouping_cols,
training_table_name, use_existing_tables=False,
running_cv=False, k=0):
""" Create the output table for grouping case.
"""
grp_key = unique_string()
tree_state = unique_string()
tree_depth = unique_string()
cp_col = unique_string()
cat_levels_in_text = unique_string()
cat_n_levels = unique_string()
impurity_var_importance = unique_string()
cat_levels_val = cat_levels_in_text if cat_features else "NULL::TEXT[]"
cat_n_levels_val = cat_n_levels if cat_features else "NULL::INTEGER[]"
grouping_array_str = bins['grouping_array_str']
fold = ", " + str(k) + " as k" if running_cv else ""
if use_existing_tables:
# plpy.execute("truncate " + output_table_name)
header = "insert into " + output_table_name + " "
else:
header = "create table " + output_table_name + " as "
sql = header + """
SELECT
{grouping_cols},
{tree_state} AS tree,
{cat_levels_val} AS cat_levels_in_text,
{cat_n_levels_val} AS cat_n_levels,
string_to_array(trim(both '{{}}' FROM {impurity_var_importance}),
',')::double precision[] AS impurity_var_importance,
{tree_depth} AS tree_depth,
{cp_col} AS pruning_cp
{fold}
FROM (
SELECT
{grouping_cols},
{grouping_array_str} AS {grp_key}
FROM {training_table_name}
group by {grouping_cols}
) s1
JOIN (
SELECT
unnest($1) AS {grp_key},
unnest($2) AS {tree_state},
unnest($3) AS {tree_depth},
unnest($4) AS {cp_col},
unnest($5) AS {impurity_var_importance}
) s2
USING ({grp_key})
"""
if cat_features:
sql += """
JOIN (
SELECT
grp_key as {grp_key},
cat_n_levels as {cat_n_levels},
cat_levels_in_text as {cat_levels_in_text}
FROM
{cat_features_info_table}
) s3
USING ({grp_key})
"""
sql = sql.format(**locals())
prepare_list = ['text[]',
'{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
'integer[]', 'double precision[]', 'text[]']
execute_list = [
[t['grp_key'] for t in tree_states],
[t['tree_state'] for t in tree_states],
[t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
for t in tree_states],
[t['cp'] for t in tree_states],
[_array_to_string(t['impurity_var_importance']) for t in tree_states]]
if cat_features:
prepare_list += ['text[]', 'integer[]', 'integer', 'text[]']
execute_list += [
bins['grp_key_cat'],
bins['cat_n'],
len(cat_features),
bins['cat_origin']]
sql_plan = plpy.prepare(sql, prepare_list)
plpy.execute(sql_plan, execute_list)
# ------------------------------------------------------------
def _get_dep_type(data_table, dep):
"""
@brief Obtain the dependent_variable type
"""
table_schema_str, table_name = _get_table_schema_names(data_table)
dep_type = plpy.execute("""
SELECT data_type from information_schema.columns
where
table_name = '{table_name}' and
table_schema in {table_schema} and
column_name = '{dep}'
""".format(table_name=table_name,
table_schema=table_schema_str,
dep=dep))
if dep_type:
return dep_type[0]['data_type']
else:
return get_expr_type(dep, data_table)
# ------------------------------------------------------------
def _create_summary_table(
schema_madlib, split_criterion,
training_table_name, output_table_name, id_col_name,
cat_features, con_features, dependent_variable, list_of_features,
list_of_features_to_exclude,
num_failed_groups, is_classification, n_all_rows, n_rows,
dep_list, all_cols_types, cp, grouping_cols=None, n_groups=1,
use_existing_tables=False, n_folds=0, null_proxy=None):
output_table_summary = add_postfix(output_table_name, "_summary")
# dependent variables
dep_type = _get_dep_type(training_table_name, dependent_variable)
if dep_list:
if is_psql_boolean_type(dep_type):
# Special handling for boolean since Python booleans are
# capitalized (i.e. False instead of false)
# Note: dep_list is sorted, hence 'false' will be first
dep_list_str = "'false, true'"
else:
dep_list_str = '$__dep_list__${0}$__dep_list__$'.format(
','.join(map(str, dep_list)))
else:
dep_list_str = "NULL"
indep_type = ', '.join(all_cols_types[c] for c in cat_features + con_features)
independent_varnames = ','.join(cat_features + con_features)
cat_features_str = ','.join(cat_features)
con_features_str = ','.join(con_features)
if grouping_cols:
grouping_cols_str = "'{0}'".format(grouping_cols)
else:
grouping_cols_str = "NULL"
n_rows_skipped = n_all_rows - n_rows
if isinstance(cp, Iterable):
cp_str = py_list_to_sql_string(cp, 'double precision')
else:
cp_str = str(cp) + "::double precision"
if use_existing_tables:
plpy.execute("TRUNCATE " + output_table_summary)
header = "INSERT INTO {0} ".format(output_table_summary)
else:
header = "CREATE TABLE {0} AS ".format(output_table_summary)
null_proxy_str="NULL" if null_proxy is None else "'{0}'".format(null_proxy)
sql = header + """
SELECT
'tree_train'::text AS method,
'{is_classification}'::boolean AS is_classification,
'{training_table_name}'::text AS source_table,
'{output_table_name}'::text AS model_table,
'{id_col_name}'::text AS id_col_name,
'{list_of_features}'::text AS list_of_features,
'{list_of_features_to_exclude}'::text AS list_of_features_to_exclude,
'{dependent_variable}'::text AS dependent_varname,
'{independent_varnames}'::text AS independent_varnames,
'{cat_features_str}'::text AS cat_features,
'{con_features_str}'::text AS con_features,
{grouping_cols_str}::text AS grouping_cols,
{n_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{n_rows}::integer AS total_rows_processed,
{n_rows_skipped}::integer AS total_rows_skipped,
{dep_list_str}::text AS dependent_var_levels,
'{dep_type}'::text AS dependent_var_type,
'{indep_type}'::text AS independent_var_types,
{cp_str} AS input_cp,
{n_folds}::integer AS n_folds,
{null_proxy_str}::text AS null_proxy
""".format(**locals())
plpy.execute(sql)
# ------------------------------------------------------------
def _get_filter_str(dependent_variable, grouping_cols):
""" Return a 'WHERE' clause string that filters out all rows that contain a
NULL.
"""
if grouping_cols:
group_filter = ' and '.join('({0}) is not NULL'.format(g.strip())
for g in grouping_cols.split(','))
else:
group_filter = None
dep_filter = '(' + dependent_variable + ") is not NULL"
return ' and '.join(filter(None, [group_filter, dep_filter]))
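# Illustrative sketch (hypothetical names): with dependent_variable 'y' and
# grouping_cols 'g1, g2' the returned clause is
#   "(g1) is not NULL and (g2) is not NULL and (y) is not NULL"
# and with no grouping columns it is simply "(y) is not NULL".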
# -------------------------------------------------------------------------
def _validate_predict(schema_madlib, model, source, output, use_existing_tables):
# validations for inputs
_assert(source and source.strip().lower() not in ('null', ''),
"Decision tree error: Invalid data table name: {0}".format(source))
_assert(table_exists(source),
"Decision tree error: Data table ({0}) does not exist".format(source))
_assert(not table_is_empty(source),
"Decision tree error: Data table ({0}) is empty".format(source))
_assert(model and
model.strip().lower() not in ('null', ''),
"Decision tree error: Invalid model table name: {0}".format(model))
_assert(table_exists(model),
"Decision tree error: Model table ({0}) does not exist".format(model))
_assert(not table_is_empty(model),
"Decision tree error: Model table ({0}) is empty".format(model))
model_summary = add_postfix(model, "_summary")
_assert(table_exists(model_summary),
"Decision tree error: Model summary table ({0}) does not exist".format(model_summary))
_assert(not table_is_empty(model_summary),
"Decision tree error: Model summary table ({0}) is empty".format(model_summary))
_assert(output and
output.strip().lower() not in ('null', ''),
"Decision tree error: Invalid output table name: {0}".format(output))
if not use_existing_tables:
_assert(not table_exists(output),
"Decision tree error: Output table ({0}) already exists".format(output))
_assert(
columns_exist_in_table(
model,
["tree", "cat_levels_in_text", "cat_n_levels"],
schema_madlib),
"Decision tree error: Invalid model table ({0})".format(model))
_assert(
columns_exist_in_table(
model_summary,
["grouping_cols", "id_col_name", "dependent_varname",
"cat_features", "con_features", "is_classification"],
schema_madlib),
"Decision tree error: Invalid model summary table ({0})".format(model_summary))
# -------------------------------------------------------------------------
def tree_predict(schema_madlib, model, source, output, pred_type='response',
use_existing_tables=False, k=0, **kwargs):
"""
Args:
@param schema_madlib: str, Name of MADlib schema
@param model: str, Name of table containing the tree model
@param source: str, Name of table containing prediction data
@param output: str, Name of table to output the results
@param pred_type: str, The type of output required:
'response' gives the predicted response values,
'prob' gives the probability of each class in a
classification tree.
For a regression tree, only pred_type='response' is defined.
Returns:
None
Side effect:
Creates an output table containing the prediction for given source table
Throws:
None
"""
_validate_predict(schema_madlib, model, source, output, use_existing_tables)
model_summary = add_postfix(model, "_summary")
# obtain the cat_features and con_features from model summary table
summary_elements = plpy.execute("SELECT * FROM {0}".format(model_summary))[0]
list_of_features = split_quoted_delimited_str(summary_elements["independent_varnames"])
cat_features = split_quoted_delimited_str(summary_elements["cat_features"])
con_features = split_quoted_delimited_str(summary_elements["con_features"])
_assert(is_var_valid(source, ','.join(list_of_features)),
"Decision tree error: Missing columns in predict data table ({0}) "
"that were used during training".format(source))
id_col_name = summary_elements["id_col_name"]
dep_varname = summary_elements["dependent_varname"]
dep_levels = split_quoted_delimited_str(summary_elements["dependent_var_levels"])
dep_type = summary_elements['dependent_var_type']
is_classification = summary_elements["is_classification"]
# optional variables, default value is None
grouping_cols_str = summary_elements.get("grouping_cols")
null_proxy = summary_elements.get('null_proxy')
# find which columns are of type boolean
boolean_cats = set([key for key, value in get_cols_and_types(source)
if value == 'boolean'])
cat_features_str, con_features_str = get_feature_str(
schema_madlib, source, cat_features, con_features,
"m.cat_levels_in_text", "m.cat_n_levels", null_proxy)
if use_existing_tables and table_exists(output):
plpy.execute("TRUNCATE " + output)
header = "INSERT INTO " + output
use_fold = 'WHERE k = ' + str(k)
else:
header = "CREATE TABLE " + output + " AS "
use_fold = ''
if not grouping_cols_str:
using_str = ""
join_str = ","
else:
using_str = "USING ( " + grouping_cols_str + ")"
join_str = "LEFT OUTER JOIN"
pred_name = ('"prob_{0}"' if pred_type == "prob" else
'"estimated_{0}"').format(dep_varname.replace('"', '').strip())
if not is_classification:
sql = header + """
SELECT {id_col_name},
{schema_madlib}._predict_dt_response(
tree,
{cat_features_str}::INTEGER[],
{con_features_str}::DOUBLE PRECISION[]) as {pred_name}
FROM {source} as s
{join_str}
{model} as m
{using_str}
{use_fold}
"""
else:
if is_psql_boolean_type(dep_type):
# some platforms don't have a text-to-boolean cast, so we map the string manually.
dep_cast_str = ("(case {pred_name} when 'true' then true "
" when 'false' then false "
"end)::BOOLEAN as {pred_name}")
else:
dep_cast_str = "{pred_name}::{dep_type}"
dep_levels_array_str = py_list_to_sql_string(map(quote_literal, dep_levels),
'TEXT',
long_format=True)
if pred_type == "response":
sql = header + """
SELECT
{id_col_name},
%s
FROM (
SELECT
{id_col_name},
-- _predict_dt_response returns 0-based indexing.
-- Hence the "+ 1" (DB by default uses 1-based indexing)
(%s)[{schema_madlib}._predict_dt_response (
tree,
{cat_features_str}::INTEGER[],
{con_features_str}::DOUBLE PRECISION[]) + 1]::TEXT
as {pred_name}
FROM {source} as s {join_str} {model} as m {using_str}
{use_fold}
) q
""" % (dep_cast_str, dep_levels_array_str)
else:
temp_col = unique_string()
score_format = ', \n'.join([
'{0}[{1}] as "estimated_prob_{2}"'.format(temp_col, i, c.strip(' "'))
for i, c in enumerate(dep_levels, start=1)])
sql = header + """
SELECT
{id_col_name},
%s
FROM (
SELECT {id_col_name},
{schema_madlib}._predict_dt_prob(tree,
{cat_features_str}::INTEGER[],
{con_features_str}::DOUBLE PRECISION[])
AS {temp_col}
FROM {source} as s {join_str} {model} as m {using_str}
{use_fold}
) q
""" % (score_format)
sql = sql.format(**locals())
with MinWarning('warning'):
with OptimizerControl(False):
# we disable optimizer (ORCA) for platforms that use it
# since ORCA doesn't provide an easy way to disable hashagg
with HashaggControl(False):
# we disable hashagg since large number of groups could
# result in excessive memory usage.
plpy.execute(sql)
return None
# -------------------------------------------------------------------------
def _get_display_header(table_name, dep_levels, is_regression, dot_format=True):
if dot_format:
tree_type = ("Classification", "Regression")[is_regression]
return ("digraph {0} {{".format('"' + tree_type + ' tree for ' +
str(table_name) + '"'))
else:
return_str = """-------------------------------------
- Each node is represented by its 'id' inside ().
- Each internal node has the split condition at the end, while each
leaf node has a * at the end.
- For each internal node (i), its child nodes are indented by 1 level,
with id (2i+1) for the True node and (2i+2) for the False node.
"""
if is_regression:
return_str += ("- Number of rows and average response value inside []. "
"For a leaf node, this is the prediction.\n")
else:
return_str += """- Number of (weighted) rows for each response variable inside [].'
The response label order is given as {0}.
For each leaf, the prediction is given after the '-->'
""".format(str(dep_levels))
return_str += "\n-------------------------------------"
return return_str
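# Example (illustrative arguments): _get_display_header('my_table', ['a', 'b'],
# False, dot_format=True) returns
#   digraph "Classification tree for my_table" {
# ('Classification' because is_regression=False); with dot_format=False the
# legend block built above is returned instead.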
# ------------------------------------------------------------------------------
def tree_display(schema_madlib, model_table, dot_format=True, verbose=False,
disp_surr=False, **kwargs):
if dot_format:
disp_surr = False # surrogates cannot be displayed in dot format
bytea8 = schema_madlib + '.bytea8'
summary_table = add_postfix(model_table, "_summary")
summary = plpy.execute("SELECT * FROM {summary_table}".
format(summary_table=summary_table))[0]
dep_levels = summary["dependent_var_levels"]
dep_levels = [''] if not dep_levels else split_quoted_delimited_str(dep_levels)
table_name = summary["source_table"]
is_regression = not summary["is_classification"]
cat_features_str = split_quoted_delimited_str(summary['cat_features'])
con_features_str = split_quoted_delimited_str(summary['con_features'])
grouping_cols = summary["grouping_cols"]
grouping_cols = '' if grouping_cols is None else grouping_cols + ','
grouped_trees = plpy.execute("SELECT {grouping_cols} "
"tree, cat_levels_in_text, cat_n_levels "
"FROM {model_table}".
format(model_table=model_table,
grouping_cols=grouping_cols))
return_str_list = []
if not disp_surr:
return_str_list.append(_get_display_header(table_name, dep_levels,
is_regression, dot_format))
else:
return_str_list.append("""
-------------------------------------
Surrogates for internal nodes
-------------------------------------
""")
with MinWarning('warning'):
for index, dtree in enumerate(grouped_trees):
grouping_keys = [k + ' = ' + str(dtree[k]) for k in dtree.keys()
if k not in ('tree', 'cat_features', 'con_features',
'cat_n_levels', 'cat_levels_in_text')]
if grouping_keys:
group_name = "(" + ','.join(grouping_keys) + ")"
else:
group_name = ''
tree = dtree['tree']
if dtree['cat_levels_in_text']:
cat_levels_in_text = dtree['cat_levels_in_text']
cat_n_levels = dtree['cat_n_levels']
else:
cat_levels_in_text = []
cat_n_levels = []
if disp_surr:
if group_name:
return_str_list.append("--- Surrogates for tree {0} ---".format(group_name))
sql = """SELECT {0}._display_decision_tree_surrogate(
$1, $2, $3, $4, $5) as display_tree
""".format(schema_madlib)
# execute sql to get display string
sql_plan = plpy.prepare(sql, [bytea8,
'text[]', 'text[]', 'text[]',
'int[]'])
tree_display = plpy.execute(
sql_plan, [tree, cat_features_str, con_features_str,
cat_levels_in_text, cat_n_levels])[0]
else:
if dot_format:
return_str_list.append('\t subgraph "cluster{0}"{{'.format(index))
return_str_list.append('\t label="{0}"'.format(group_name.replace('"', '\\"')))
sql = """
SELECT {0}._display_decision_tree(
$1, $2, $3, $4, $5, $6, '{1}', {2}
) as display_tree
""".format(schema_madlib, "g" + str(index) + "_", verbose)
else:
if group_name:
return_str_list.append("--- Tree for {0} ---".format(group_name))
sql = """
SELECT {0}._display_text_decision_tree(
$1, $2, $3, $4, $5, $6) as display_tree
""".format(schema_madlib)
sql_plan = plpy.prepare(sql, [bytea8,
'text[]', 'text[]', 'text[]',
'int[]', 'text[]'])
tree_display = plpy.execute(
sql_plan, [tree, cat_features_str, con_features_str,
cat_levels_in_text, cat_n_levels, dep_levels])[0]
return_str_list.append(tree_display["display_tree"])
if dot_format:
return_str_list.append("\t } //--- end of subgraph------------")
else:
return_str_list.append("-------------------------------------")
if dot_format:
return_str_list.append("} //---end of digraph--------- ")
return ("\n".join(return_str_list))
# -------------------------------------------------------------------------
def _prune_and_cplist(schema_madlib, tree, cp, compute_cp_list=False):
""" Prune tree with given cost-complexity parameters
and return a list of cp values at which tree can be pruned
Args:
@param schema_madlib: str, MADlib schema name
@param tree: Tree data to prune
@param cp: float, cost-complexity parameter, all splits that have a
complexity lower than 'cp' will be pruned
@param compute_cp_list: bool, optionally return a list of cp values that
are the various complexity boundaries for the splits
in the tree. This list can be used by cross-validation
to explore different cp values.
Returns:
Dictionary containing following keys:
tree_state: pruned tree state
pruned_depth: depth of tree after pruning
cp_list: list of cp values at which tree can be pruned
(returned only if compute_cp_list=True)
"""
sql = """
SELECT (pruned_tree).*
FROM (
SELECT {madlib}._prune_and_cplist(
$1,
({cp})::double precision,
({compute_cp_list})::boolean
) as pruned_tree
) q
""".format(madlib=schema_madlib, cp=cp,
compute_cp_list=bool(compute_cp_list))
sql_plan = plpy.prepare(sql, [schema_madlib + '.bytea8'])
pruned_tree = plpy.execute(sql_plan, [tree['tree_state']])[0]
return pruned_tree
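# Usage sketch: the result behaves like a dictionary, e.g.
#   pruned = _prune_and_cplist(schema_madlib, tree, 0.01, compute_cp_list=True)
#   pruned['tree_state']    # pruned tree state (bytea8)
#   pruned.get('cp_list')   # candidate cp values, present only when requested
# Callers such as _xvalidate below also look for 'pruned_depth'/'tree_depth'.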
# -------------------------------------------------------------------------
def _compute_var_importance(schema_madlib, tree,
n_cat_features, n_con_features):
""" Compute variable importance for categorical and continuous features
Args:
@param schema_madlib: str, MADlib schema name
@param tree: dict. tree['tree_state'] is the trained tree (in byte form)
@param n_cat_features: int, Number of categorical features
@param n_con_features: int, Number of continuous features
Returns:
Dictionary containing following keys:
impurity_var_importance: Array of importance values
"""
var_imp_sql = """
SELECT {schema_madlib}._compute_var_importance(
$1, -- trained decision tree
{n_cat_features},
{n_con_features}) AS impurity_var_importance
""".format(**locals())
var_imp_plan = plpy.prepare(var_imp_sql, [schema_madlib + '.bytea8'])
return plpy.execute(var_imp_plan, [tree['tree_state']])[0]
# ------------------------------------------------------------------------------
def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_name,
id_col_name, dependent_variable,
list_of_features, list_of_features_to_exclude,
cat_features, ordered_cat_features, boolean_cats, con_features,
split_criterion, grouping_cols, weights, max_depth,
min_split, min_bucket, n_bins, is_classification,
dep_is_bool, dep_n_levels, n_folds, n_rows,
max_n_surr, null_proxy, msg_level='warning', **kwargs):
"""
Run cross validation for decision tree over multiple cp values
Returns:
cp value corresponding to lowest cross validation error
Side effect:
Creates a table containing cross-validation error for each cp value
"""
# 1) create group_to_param_list_table for CV
group_to_param_list_table = unique_string()
param_list_name = unique_string()
if not grouping_cols:
plan = plpy.prepare("""
CREATE TEMP TABLE {group_to_param_list_table} AS
SELECT $1 AS {param_list_name}
""".format(**locals()), ["float8[]"])
plpy.execute(plan, [tree_states[0]['cp_list']])
plpy.notice("Running cross validation for cp values: {0}".format(
str(tree_states[0]['cp_list'])))
else:
grp_list = []
cp_list = []
plpy.notice("Running cross validation for cp values:")
for tree in tree_states:
grp_list.extend([tree['grp_key']]*len(tree['cp_list']))
cp_list.extend(tree['cp_list'])
plpy.notice("{0} -> {1}".format(tree['grp_key'], str(tree['cp_list'])))
grouping_array_str = get_grouping_array_str(training_table_name,
grouping_cols)
grp_key_str = unique_string()
plan = plpy.prepare("""
CREATE TEMP TABLE {group_to_param_list_table} AS
SELECT
{grouping_cols},
{param_list_name}
FROM
(
SELECT {grp_key_str},
array_agg(cp_value) as {param_list_name}
FROM
(
SELECT unnest($1) as {grp_key_str},
unnest($2) as cp_value
) unnested_grp_cp
GROUP BY {grp_key_str}
) grp_key_to_cp_list
JOIN
(
SELECT
{grouping_cols},
{grouping_array_str} as {grp_key_str}
FROM {training_table_name}
) grp_key_to_grouping_cols
USING ({grp_key_str})
""".format(**locals()), ["text[]", "float8[]"])
plpy.execute(plan, [grp_list, cp_list])
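# The resulting table maps each combination of grouping-column values to the
# array of candidate cp values ({param_list_name}) explored for that group
# during cross-validation.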
# 2) call CV function to actually cross-validate _build_tree
# expects output table model_cv(<grouping_cols>, cp, avg, stddev)
model_cv = output_table_name + "_cv"
metric_function = "_tree_misclassified" if is_classification else "_tree_rmse"
pred_name = '"estimated_{0}"'.format(dependent_variable.strip(' "'))
grouping_str = 'NULL' if not grouping_cols else '"' + grouping_cols + '"'
all_features = [cat_features, ordered_cat_features, boolean_cats, con_features]
# _get_xvalidate_params builds the parameters used in
# DT train, predict, distance functions. Single quotes are added in these
# parameters (except for the feature arrays) since we run
# cross_validation_grouping_w_params with `add_param_quotes=False'.
# This special handling is put in place to ensure the feature arrays are
# treated as arrays instead of strings.
xvalidate_params = _get_xvalidate_params(**locals())
cross_validation_grouping_w_params(
schema_madlib,
schema_madlib + '.__build_tree',
xvalidate_params[0],
xvalidate_params[1],
schema_madlib + '.__tree_predict',
xvalidate_params[2],
xvalidate_params[3],
schema_madlib + "." + metric_function,
xvalidate_params[4],
xvalidate_params[5],
group_to_param_list_table, param_list_name, grouping_cols,
training_table_name, id_col_name, False,
model_cv, 'cp', None, n_folds,
add_param_quotes=False)
# 3) find the best cp for each group from table {model_cv}
if not grouping_cols:
grouping_array_str_comma = "''::TEXT AS grp_key,"
group_by_str = ''
qualified_group_by_str = ''
group_select_str = ''
else:
grouping_array_qualified_str = get_grouping_array_str(training_table_name,
grouping_cols,
qualifier="min_cv_error")
grouping_array_str_comma = grouping_array_qualified_str + " AS grp_key,"
grouping_cols_list = grouping_cols.split(",")
qualified_group_by_str = "GROUP BY " + ','.join("min_cv_error." + i for i in grouping_cols_list)
group_by_str = "GROUP BY " + grouping_cols
group_select_str = grouping_cols + ","
plpy.notice(str(list(plpy.execute("SELECT * FROM {0}".format(model_cv)))))
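# For each group, among all cp values that achieve the minimum
# cross-validation error, pick the largest one, i.e. the most aggressively
# pruned (simplest) tree with the best error.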
validation_result_query = """
SELECT
{grouping_array_str_comma}
max(cp) AS cp
FROM
{model_cv}
NATURAL JOIN
( SELECT
{group_select_str}
min(cv_error_avg) AS cv_error_avg
FROM {model_cv}
{group_by_str}
) min_cv_error
{qualified_group_by_str}
""".format(**locals())
plpy.notice("validation_result_query:")
plpy.notice(validation_result_query)
validation_result = plpy.execute(validation_result_query)
plpy.notice("finished validation_result_query, validation_result = " + str(list(validation_result)))
grp_key_to_best_cp = dict((row['grp_key'], row['cp']) for row in validation_result)
# 4) update tree_states to have the best cp cross-validated
for tree in tree_states:
best_cp = grp_key_to_best_cp[tree['grp_key']]
if best_cp > tree['cp']:
tree['cp'] = best_cp
# Prune each tree with the cross-validated cp (which may be higher than
# the original cp), giving the final pruned tree.
# The cp_list is not needed this time.
pruned_tree = _prune_and_cplist(schema_madlib,
tree,
tree['cp'],
compute_cp_list=False)
tree['tree_state'] = pruned_tree['tree_state']
if 'pruned_depth' in pruned_tree:
tree['pruned_depth'] = pruned_tree['pruned_depth']
elif 'tree_depth' in pruned_tree:
tree['pruned_depth'] = pruned_tree['tree_depth']
else:
tree['pruned_depth'] = 0
importance_vectors = _compute_var_importance(
schema_madlib, tree,
len(cat_features), len(con_features))
tree.update(**importance_vectors)
plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals()))
# ------------------------------------------------------------
def _get_xvalidate_params(**kwargs):
""" Build train, predict, and metric parameters for cross_validation
Args:
@param all_features
Returns:
"""
def _list_to_string_to_array(array_input):
""" Return a string that can interpreted by postgresql as text[] containing
the names in array_input
Example:
Input: ['"Cont_features"[1]', '"Cont_features"[2]']
Output: string_to_array('"Cont_features"[1]~^~"Cont_features"[2]'::text, '~^~');
When this output is executed by PostgreSQL, it creates a text array:
madlib=# select string_to_array('"Cont_features"[1]~^~"Cont_features"[2]'::text, '~^~')::VARCHAR[] as t;
t
-------------------------------------------------
{"\"Cont_features\"[1]","\"Cont_features\"[2]"}
(1 row)
"""
if not array_input:
return "'{}'"
return "string_to_array('{0}', '~^~')".format('~^~'.join(array_input))
all_feature_str = [_list_to_string_to_array(i) for i in kwargs['all_features']]
def _add_quote(s):
if s is None:
return None
s = str(s)
return "NULL" if s.upper() == 'NULL' else "'{0}'".format(s)
quoted_args = {}
for k, v in kwargs.items():
quoted_args[k] = _add_quote(v)
modeling_params = [quoted_args['is_classification'],
quoted_args['split_criterion'],
"%data%",
"%model%",
quoted_args['id_col_name'],
quoted_args['dependent_variable'],
quoted_args['dep_is_bool'],
quoted_args['list_of_features'],
all_feature_str[0],
all_feature_str[1],
all_feature_str[2],
all_feature_str[3],
quoted_args['grouping_str'],
quoted_args['weights'],
quoted_args['max_depth'],
quoted_args['min_split'],
quoted_args['min_bucket'],
quoted_args['n_bins'],
"%explore%",
quoted_args['max_n_surr'],
quoted_args['msg_level'],
quoted_args['null_proxy']
]
modeling_param_types = (["BOOLEAN"] + ["TEXT"] * 5 + ["BOOLEAN"] + ["TEXT"] +
["VARCHAR[]"] * 4 + ["TEXT"] * 2 + ["INTEGER"] * 4 +
["TEXT", "SMALLINT", "TEXT", "TEXT"])
predict_params = ["%model%", "%data%", "%prediction%", "'response'", "True"]
predict_param_types = (["VARCHAR"] * 4 + ["BOOLEAN"])
metric_params = ["%data%",
quoted_args['dependent_variable'],
"%prediction%",
quoted_args['pred_name'],
quoted_args['id_col_name'],
quoted_args['grouping_cols'],
"%error%",
"True"]
metric_param_types = ["VARCHAR", "VARCHAR", "VARCHAR", "VARCHAR",
"VARCHAR", "TEXT", "VARCHAR", "BOOLEAN"]
return [modeling_params, modeling_param_types,
predict_params, predict_param_types,
metric_params, metric_param_types]
# ----------------------------------------------------------------------
def _tree_train_using_bins(
schema_madlib, bins, training_table_name,
cat_features, con_features, boolean_cats, n_bins, weights,
dep_var_str, min_split, min_bucket, max_depth, filter_dep,
dep_n_levels, is_classification, split_criterion,
subsample=False, n_random_features=1, max_n_surr=0, null_proxy=None,
**kwargs):
"""Trains a tree without grouping columns"""
# Iterations for training the tree
tree_state = plpy.execute(
"""
SELECT {schema_madlib}._initialize_decision_tree(
{is_regression_tree},
'{split_criterion}'::text,
{dep_n_levels}::smallint,
{max_n_surr}::smallint
) AS tree_state,
FALSE as finished
""".format(schema_madlib=schema_madlib,
is_regression_tree=(not is_classification),
split_criterion=split_criterion,
dep_n_levels=dep_n_levels,
max_n_surr=max_n_surr))[0]
plpy.notice("Starting tree building")
tree_depth = -1
while tree_state['finished'] == 0:
# finished: 0 = running, 1 = finished training, 2 = terminated prematurely
tree_depth += 1
tree_state = _one_step(
schema_madlib, training_table_name,
cat_features, con_features, boolean_cats, bins,
n_bins, tree_state, weights, dep_var_str,
min_split, min_bucket, max_depth, filter_dep,
dep_n_levels, subsample, n_random_features, max_n_surr, null_proxy)
plpy.notice("Completed training of level {0}".format(tree_depth))
return tree_state
# ------------------------------------------------------------------------------
def _tree_train_grps_using_bins(
schema_madlib, bins, training_table_name,
cat_features, con_features, boolean_cats, n_bins, weights,
grouping_cols, grouping_array_str,
dep_var_str, min_split, min_bucket, max_depth, filter_dep,
dep_n_levels, is_classification, split_criterion,
cat_features_info_table,
subsample=False, n_random_features=1, tree_terminated=None,
max_n_surr=0, null_proxy=None,
**kwargs):
"""Trains a tree with grouping columns included """
# Iterations for training the tree
initialized_tree_state = plpy.execute(
"""
SELECT
{schema_madlib}._initialize_decision_tree(
{is_regression_tree},
'{split_criterion}'::text,
{dep_n_levels}::smallint,
{max_n_surr}::smallint) AS tree_state,
0 AS finished
""".format(schema_madlib=schema_madlib,
is_regression_tree=not is_classification,
split_criterion=split_criterion,
dep_n_levels=dep_n_levels,
max_n_surr=max_n_surr))[0]
tree_states = []
for key in bins['grp_key_con']:
group_copy = dict(initialized_tree_state, grp_key=key)
tree_states.append(group_copy)
# The following is only used in random forest:
# If a group already has a terminated tree (not finished
# properly), then we do not continue to compute more trees
# for that specific group of data.
if tree_terminated is not None:
for item in tree_states:
if item['grp_key'] in tree_terminated and \
tree_terminated[item['grp_key']] == 2: # terminated
item['finished'] = 2 # won't continue in _one_step_for_grps
plpy.notice("Started tree building for all groups")
level = 0
while not all(t['finished'] for t in tree_states):
tree_states = _one_step_for_grps(
schema_madlib, training_table_name, cat_features,
con_features, boolean_cats, bins, n_bins,
tree_states, weights, grouping_cols,
grouping_array_str, dep_var_str, min_split, min_bucket,
max_depth, filter_dep, dep_n_levels, subsample,
n_random_features, cat_features_info_table,
max_n_surr, null_proxy)
level += 1
plpy.notice("Finished training for level " + str(level))
return tree_states
# ------------------------------------------------------------
def _tree_error(schema_madlib, source_table, dependent_varname,
prediction_table, pred_dep_name, id_col_name, grouping_cols,
output_table, is_classification,
use_existing_tables=False, k=0, **kwargs):
with MinWarning("warning"):
if use_existing_tables and table_exists(output_table):
# plpy.execute("truncate " + output_table)
header = "INSERT INTO " + output_table + " "
else:
header = "CREATE TABLE " + output_table + " AS "
if is_classification:
error_func = """
1.0 * sum(CASE WHEN ({prediction_table}.{pred_dep_name} =
{source_table}.{dependent_varname})
THEN 0
ELSE 1
END) / count(*)
""".format(**locals())
else:
error_func = """
sqrt(avg(({prediction_table}.{pred_dep_name} -
{source_table}.{dependent_varname})^2
)
)
""".format(**locals())
grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
grouping_col_str = '' if not grouping_cols else grouping_cols + ','
sql = header + """
SELECT
{grouping_col_str}
{error_func} as cv_error,
{k} as k
FROM {prediction_table}, {source_table}
WHERE {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
{grouping_str}
""".format(**locals())
plpy.execute(sql)
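# In both cases the generated query joins the prediction table with the
# source table on the id column and emits one cv_error row (per group when
# grouping is used): the misclassification rate for classification, the
# RMSE for regression.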
# ------------------------------------------------------------
def tree_train_help_message(schema_madlib, message, **kwargs):
""" Help message for Decision Tree
"""
if not message:
help_string = """
------------------------------------------------------------
SUMMARY
------------------------------------------------------------
Functionality: Decision Tree
A decision tree is a tree-structured predictive model that predicts
the value of a target variable from several input variables.
For more details on the function usage:
SELECT {schema_madlib}.tree_train('usage');
"""
elif message.lower().strip() in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------
USAGE
------------------------------------------------------------
SELECT {schema_madlib}.tree_train(
'training_table', -- Data table name
'output_table', -- Table name to store the tree model
'id_col_name', -- Row ID, used in tree_predict
'dependent_variable', -- The column to fit
'list_of_features', -- Comma separated column names to be
used as the predictors, can be '*'
to include all columns except the
dependent_variable
'features_to_exclude', -- Comma separated column names to be
excluded if list_of_features is '*'
'split_criterion', -- How to split a node, options are
'gini', 'misclassification' and
'entropy' for classification, and
'mse' for regression.
'grouping_cols', -- Comma separated column names used to
group the data. A decision tree model
will be created for each group. Default
is NULL
'weights', -- A column name containing weights for
each observation. Default is NULL
max_depth, -- Maximum depth of any node, default is 7
min_split, -- Minimum number of observations that must
exist in a node for a split to be
attempted, default is 20
min_bucket, -- Minimum number of observations in any
terminal node, default is min_split/3
n_bins, -- Number of bins to find possible node
split threshold values for continuous
variables, default is 20 (Must be greater than 1)
pruning_params, -- A comma-separated text containing
key=value pairs of parameters for pruning.
Parameters accepted:
'cp' - complexity parameter with default=0.01,
'n_folds' - number of cross-validation folds
with default value of 0 (= no cross-validation)
null_handling_params, -- A comma-separated text containing
key=value pairs of parameters for handling NULL values.
Parameters accepted:
'max_surrogates' - Maximum number of surrogates to
compute for each split
'null_as_category' - Boolean to indicate if
NULL should be treated as a special category
verbose -- Boolean, whether to print more info, default is False
);
------------------------------------------------------------
OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the following columns (quoted items
are of type TEXT):
<grouping columns> -- Grouping columns, only present when
'grouping_cols' is not NULL or ''
tree -- The decision tree model as a binary string
cat_levels_in_text -- Distinct levels (cast to text) of all
categorical variables combined in a single array
cat_n_levels -- Number of distinct levels of all categorical variables
tree_depth -- Number of levels in the tree (root has level 0)
pruning_cp -- The cost-complexity parameter used for pruning
the trained tree(s). This would be different
from the input cp value if cross-validation is used.
The output summary table ('output_table_summary') has the following columns:
'method' -- Method name: 'tree_train'
'source_table' -- Data table name
'model_table' -- Tree model table name
'id_col_name' -- Name of the 'id' column
is_classification -- Boolean value indicating if tree is classification or regression
'dependent_varname' -- Response variable column name
'independent_varnames' -- Comma-separated feature column names
'cat_features' -- Comma-separated column names of categorical variables
'con_features' -- Comma-separated column names of continuous variables
'grouping_cols' -- Grouping column names
num_all_groups -- Number of groups
num_failed_groups -- Number of groups for which training failed
total_rows_processed -- Number of rows used in the model training
total_rows_skipped -- Number of rows skipped because of NULL values
dependent_var_levels -- For classification, the distinct levels of
the dependent variable
dependent_var_type -- The type of dependent variable
input_cp -- The complexity parameter (cp) used for pruning the
trained tree(s) (before cross-validation is run)
independent_var_types -- The types of independent variables, comma-separated
n_folds -- Number of cross-validation folds (0 if cross validation is not used)
null_proxy -- String used as replacement for NULL values
(NULL if null_as_category = False)
"""
else:
help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
return help_string.format(schema_madlib=schema_madlib)
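# Illustrative call matching the usage text above (assumes MADlib is installed
# in the 'madlib' schema; table and column names are hypothetical, and the
# trailing optional arguments are omitted):
#   SELECT madlib.tree_train('dt_golf', 'train_output', 'id', 'class',
#                            'outlook, temperature, humidity, windy', NULL,
#                            'gini', NULL, NULL, 5, 4, 2, 10);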
# ------------------------------------------------------------
def tree_predict_help_message(schema_madlib, message, **kwargs):
""" Help message for Decision Tree predict
"""
if not message:
help_string = """
------------------------------------------------------------
SUMMARY
------------------------------------------------------------
Functionality: Decision Tree Prediction
Prediction for a decision tree (trained using {schema_madlib}.tree_train) can
be performed on a new data table.
For more details on the function usage:
SELECT {schema_madlib}.tree_predict('usage');
For an example on using this function:
SELECT {schema_madlib}.tree_predict('example');
"""
elif message.lower().strip() in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------
USAGE
------------------------------------------------------------
SELECT {schema_madlib}.tree_predict(
'tree_model', -- Model table name (output of tree_train)
'new_data_table', -- Prediction source table
'output_table', -- Table name to store the predictions
'type' -- Type of prediction output
);
Note: The 'new_data_table' should have the same 'id_col_name' column as used
in the training function. This is used to correlate each row of the prediction
data with its prediction in the output table.
------------------------------------------------------------
OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the '<id_col_name>' column giving
the 'id' for each prediction and the prediction columns for the response
variable (also called the dependent variable).
If prediction type = 'response', then the table has a single prediction column
containing the predicted value of the response. The type of this column depends
on the type of the response variable used during training.
If prediction type = 'prob', then the table has multiple columns, one for each
possible value of the response variable. The columns are labeled as
'estimated_prob_<dep value>', where <dep value> represents each distinct value
of the response variable.
"""
else:
help_string = "No such option. Use {schema_madlib}.tree_predict('usage')"
return help_string.format(schema_madlib=schema_madlib)
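# Illustrative call matching the usage text above (assumes MADlib is installed
# in the 'madlib' schema; table names are hypothetical):
#   SELECT madlib.tree_predict('train_output', 'dt_golf', 'prediction_results',
#                              'response');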
# ------------------------------------------------------------