# coding=utf-8
"""
@file random_forest.py_in
@brief Random Forest: Driver functions
@namespace random_forest
"""
import plpy
from math import sqrt
from utilities.control import MinWarning
from utilities.control import EnableOptimizer
from utilities.control import EnableHashagg
from utilities.validate_args import get_cols_and_types
from utilities.validate_args import is_var_valid
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.utilities import _assert
from utilities.utilities import unique_string
from utilities.utilities import add_postfix
from utilities.utilities import split_quoted_delimited_str
from utilities.utilities import extract_keyvalue_params
from decision_tree import _tree_train_using_bins
from decision_tree import _tree_train_grps_using_bins
from decision_tree import _get_bins
from decision_tree import _get_bins_grps
from decision_tree import _get_features_to_use
from decision_tree import _get_dep_type
from decision_tree import _is_dep_categorical
from decision_tree import _get_n_and_deplist
from decision_tree import _classify_features
from decision_tree import _get_filter_str
from decision_tree import _dict_get_quoted
from decision_tree import _get_display_header
from decision_tree import get_feature_str
# ------------------------------------------------------------
def forest_train_help_message(schema_madlib, message, **kwargs):
""" Help message for Random Forest
"""
if not message:
help_string = """
------------------------------------------------------------
SUMMARY
------------------------------------------------------------
Functionality: Random Forest
Random forests build an ensemble of decision trees to predict the value
of a target variable based on several input variables.
For more details on the function usage:
SELECT {schema_madlib}.forest_train('usage');
For an example on using this function:
SELECT {schema_madlib}.forest_train('example');
"""
elif message.lower().strip() in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------
USAGE
------------------------------------------------------------
SELECT {schema_madlib}.forest_train(
'training_table', -- Data table name
'output_table', -- Table name to store the forest model
'id_col_name', -- Row ID, used in forest_predict
'dependent_variable', -- The column to fit
'list_of_features', -- Comma separated column names to be
used as the predictors, can be '*'
to include all columns except the
dependent_variable
'features_to_exclude', -- Comma separated column names to be
excluded if list_of_features is '*'
'grouping_cols', -- Comma separated column names used to
group the data. A random forest model
will be created for each group. Default
is NULL
num_trees, -- Integer, default: 100. Maximum number of trees
to grow in the Random Forest model. Actual
number of trees grown may be slightly different.
num_random_features, -- Integer, default: sqrt(n) if classification tree,
otherwise n/3. Number of features to randomly
select at each split.
importance, -- Boolean, whether to calculate variable importance,
default is True
num_permutations -- Number of times to permute each feature value while
calculating variable importance, default is 1.
max_tree_depth, -- Maximum depth of any node, default is 10
min_split, -- Minimum number of observations that must
exist in a node for a split to be
attempted, default is 20
min_bucket, -- Minimum number of observations in any
terminal node, default is min_split/3
num_splits, -- Number of bins to find possible node
split threshold values for continuous
variables, default is 100 (Must be greater than 1)
'surrogate_params', -- Text, Comma-separated string of key-value pairs
controlling the behavior of surrogate splits for
each node in a tree. 'max_surrogates=n', where n
is a positive integer with default 0.
verbose, -- Boolean, whether to print more info,
default is False
sample_ratio -- Double precision, in the range of (0, 1], default: 1
If sample_ratio is less than 1, a bootstrap sample
smaller than the data table is used for training
each tree in the forest.
);
------------------------------------------------------------
OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the following columns
(quoted items are of text type):
gid -- Integer. group id that uniquely identifies
a set of grouping column values.
sample_id -- Integer. id of the bootstrap sample that
this tree is a part of.
tree -- bytea8. trained tree model stored in binary
format.
The output summary table ('output_table_summary') has the following
columns:
'method' -- Method name: 'forest_train'
is_classification -- boolean. True if it is a classification model.
'source_table' -- Data table name
'model_table' -- Tree model table name
'id_col_name' -- The ID column name
'dependent_varname' -- Response variable column name
'independent_varnames' -- Comma-separated feature column names
'cat_features' -- Comma-separated column names of categorical variables
'con_features' -- Comma-separated column names of continuous variables
'grouping_cols' -- Grouping column names
num_trees -- Number of trees grown by the model
num_random_features -- Number of features randomly selected for each split.
max_tree_depth -- Maximum depth of any node.
min_split -- Minimum number of observations that must
exist in a node for a split to be
attempted.
min_bucket -- Minimum number of observations in any
terminal node.
num_splits -- Number of bins to find possible node
split threshold values for continuous
variables.
verbose -- Boolean, whether to print more info.
importance -- Boolean, whether to calculate variable importance
num_permutations -- Number of times to permute each feature value while
calculating variable importance
num_all_groups -- Int. Number of groups during forest training.
num_failed_groups -- Number of groups for which training failed
total_rows_processed -- Number of rows used in the model training
total_rows_skipped -- Number of rows skipped because of NULL values
dependent_var_levels -- For classification, the distinct levels of
the dependent variable
dependent_var_type -- The type of dependent variable
A third table ('output_table_group') contains information about the grouping
and has the following columns:
gid -- Integer. group id that uniquely identifies a set
of grouping column values.
<...> -- Grouping columns, if provided in input. Same type as
in the training data table. This could be multiple
columns depending on the grouping_cols input.
success -- Boolean, indicator of the success of the group.
cat_levels_in_text -- text[]. Ordered levels of categorical variables.
cat_n_levels -- integer[]. Number of levels for each categorical variable.
oob_error -- double precision. Out-of-bag error for the random forest model.
cat_var_importance -- double precision[]. Variable importance for categorical
features. The order corresponds to the order of the
variables as found in cat_features in <output_table>_summary.
con_var_importance -- double precision[]. Variable importance for continuous
features. The order corresponds to the order of the
variables as found in con_features in <output_table>_summary.
"""
elif message.lower().strip() in ['example', 'examples']:
help_string = """
------------------------------------------------------------
EXAMPLE
------------------------------------------------------------
DROP TABLE IF EXISTS dt_golf;
CREATE TABLE dt_golf (
id integer NOT NULL,
"OUTLOOK" text,
temperature double precision,
humidity double precision,
windy text,
class text
);
INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,windy,class) VALUES
(1, 'sunny', 85, 85, 'false', 'Don''t Play'),
(2, 'sunny', 80, 90, 'true', 'Don''t Play'),
(3, 'overcast', 83, 78, 'false', 'Play'),
(4, 'rain', 70, 96, 'false', 'Play'),
(5, 'rain', 68, 80, 'false', 'Play'),
(6, 'rain', 65, 70, 'true', 'Don''t Play'),
(7, 'overcast', 64, 65, 'true', 'Play'),
(8, 'sunny', 72, 95, 'false', 'Don''t Play'),
(9, 'sunny', 69, 70, 'false', 'Play'),
(10, 'rain', 75, 80, 'false', 'Play'),
(11, 'sunny', 75, 70, 'true', 'Play'),
(12, 'overcast', 72, 90, 'true', 'Play'),
(13, 'overcast', 81, 75, 'false', 'Play'),
(14, 'rain', 71, 80, 'true', 'Don''t Play');
DROP TABLE IF EXISTS train_output, train_output_group, train_output_summary;
SELECT madlib.forest_train('dt_golf', -- source table
'train_output', -- output model table
'id', -- id column
'class', -- response
'"OUTLOOK", temperature, humidity, windy', -- features
NULL, -- exclude columns
NULL, -- grouping columns
20::integer, -- number of trees
2::integer, -- number of random features
TRUE::boolean, -- variable importance
1::integer, -- num_permutations
8::integer, -- max depth
3::integer, -- min split
1::integer, -- min bucket
10::integer -- number of splits per continuous variable
);
SELECT madlib.get_tree('train_output',1,2,FALSE);
"""
else:
help_string = "No such option. Use {schema_madlib}.forest_train('usage')"
return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------
def forest_train(
schema_madlib, training_table_name, output_table_name, id_col_name,
dependent_variable, list_of_features, list_of_features_to_exclude,
grouping_cols, num_trees, num_random_features,
importance, num_permutations, max_tree_depth,
min_split, min_bucket, num_bins,
surrogate_params, verbose=False, sample_ratio=None, **kwargs):
""" Random forest main training function
Args:
@param schema_madlib: str, MADlib schema name
@param training_table_name: str, source table name
@param output_table_name: str, model table name
@param id_col_name: str, id column name to uniquely identify each row
@param dependent_variable: str, dependent variable column name
@param list_of_features: str, Comma-separated list of feature column names,
can also be '*' implying all columns
except dependent_variable
@param list_of_features_to_exclude: str, features to exclude if '*' is used
@param grouping_cols: str, List of grouping columns to group the data
@param num_trees: int, Number of trees in the forest
@param num_random_features: int, Number of features randomly selected for each node split
@param importance: boolean, Whether or not to calculate variable importance
@param num_permutations: int, Number of times to permute each feature value
during calculation of variable importance
@param max_tree_depth: int, Maximum depth of each tree
@param min_split: int, Minimum tuples in a node before splitting it
@param min_bucket: int, Minimum tuples in each child before splitting a node
@param num_bins: int, Number of bins for quantizing continuous variables
@param surrogate_params: str, Comma-separated key-value pairs controlling
surrogate splits (currently only 'max_surrogates=<n>')
@param verbose: bool, Whether to print verbose messages
@param sample_ratio: float, subsampling ratio for generating src_view
"""
msg_level = "'notice'" if verbose else "'warning'"
with MinWarning(msg_level):
with EnableOptimizer(False):
# we disable optimizer (ORCA) for platforms that use it
# since ORCA doesn't provide an easy way to disable hashagg
with EnableHashagg(False):
# we disable hashagg since large number of groups could
# result in excessive memory usage.
##################################################################
#### set default values
if grouping_cols is not None and grouping_cols.strip() == '':
grouping_cols = None
num_trees = 100 if num_trees is None else num_trees
max_tree_depth = 10 if max_tree_depth is None else max_tree_depth
min_split = 20 if min_split is None and min_bucket is None else min_split
min_bucket = min_split // 3 if not min_bucket else min_bucket
min_split = min_bucket * 3 if not min_split else min_split
num_bins = 100 if num_bins is None else num_bins
sample_ratio = 1 if sample_ratio is None else sample_ratio
surrogate_param_dict = extract_keyvalue_params(
surrogate_params,
dict(max_surrogates=int), # type of variable
dict(max_surrogates=0)) # default value of variable
max_n_surr = surrogate_param_dict['max_surrogates']
_assert(max_n_surr >= 0,
"Maximum number of surrogates ({0}) should be non-negative".
format(max_n_surr))
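# For illustration: extract_keyvalue_params parses a string such as
# surrogate_params='max_surrogates=2' into {'max_surrogates': 2}, falling
# back to the default dict (max_surrogates=0) when the input is NULL or empty.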
##################################################################
# validate arguments
_forest_validate_args(training_table_name, output_table_name, id_col_name,
list_of_features, dependent_variable,
list_of_features_to_exclude, grouping_cols,
num_trees, num_random_features,
num_permutations, max_tree_depth,
min_split, min_bucket, num_bins, sample_ratio)
##################################################################
# preprocess arguments
# expand "*" syntax and exclude some features
features = _get_features_to_use(schema_madlib,
training_table_name,
list_of_features,
list_of_features_to_exclude,
id_col_name,
'1', dependent_variable,
grouping_cols)
_assert(bool(features),
"Random forest error: No feature is selected for the model.")
is_classification, is_bool = _is_dep_categorical(
training_table_name, dependent_variable)
split_criterion = 'gini' if is_classification else 'mse'
if num_random_features is None:
n_all_features = len(features)
num_random_features = (sqrt(n_all_features) if is_classification
else n_all_features / 3)
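# These are the standard random forest defaults (Breiman): sqrt(p) features
# per split for classification and p/3 for regression. The value may be
# fractional here; it is assumed to be truncated to an integer downstream.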
_assert(num_random_features <= len(features),
"Random forest error: Number of features to be selected "
"is more than the actual number of features.")
all_cols_types = dict(get_cols_and_types(training_table_name))
cat_features, ordered_cat_features, con_features, boolean_cats = \
_classify_features(all_cols_types, features)
filter_null = _get_filter_str(dependent_variable, grouping_cols)
# the total number of records
n_all_rows = plpy.execute("SELECT count(*) FROM {0}".
format(training_table_name))[0]['count']
if is_classification:
# For classifications, we also need to map dependent_variable to integers
n_rows, dep_list = _get_n_and_deplist(training_table_name,
dependent_variable,
filter_null)
_assert(n_rows > 0,
"Random forest error: There should be at least one "
"data point for each class where all features are non NULL")
dep_list.sort()
dep_col_str = ("CASE WHEN " + dependent_variable +
" THEN 'True' ELSE 'False' END") if is_bool else dependent_variable
dep = ("(CASE " +
"\n ".join([
"WHEN ({dep_col})::text = $${c}$$ THEN {i}".format(
dep_col=dep_col_str, c=c, i=i)
for i, c in enumerate(dep_list)]) +
"\nEND)")
dep_n_levels = len(dep_list)
else:
n_rows = plpy.execute(
"SELECT count(*) FROM {source_table} where {filter_null}".
format(source_table=training_table_name,
filter_null=filter_null))[0]['count']
dep = dependent_variable
dep_n_levels = 1
dep_list = None
# a table that maps gid/grp_key to actual columns
grp_key_to_grp_cols = unique_string()
# create the above table and perform binning
if grouping_cols is None:
sql_grp_key_to_grp_cols = """
CREATE TABLE {grp_key_to_grp_cols} AS
SELECT ''::text AS grp_key, 1 AS gid
""".format(**locals())
plpy.notice("sql_grp_key_to_grp_cols:\n" + sql_grp_key_to_grp_cols)
plpy.execute(sql_grp_key_to_grp_cols)
# find the bins, one dict containing two arrays: categorical
# bins, and continuous bins
num_groups = 1
bins = _get_bins(schema_madlib, training_table_name,
cat_features, ordered_cat_features,
con_features, num_bins, dep,
boolean_cats, n_rows, is_classification,
dep_n_levels, filter_null)
# some features may be dropped because they have only one value
cat_features = bins['cat_features']
bins['grp_key_cat'] = ['']
else:
grouping_cols_list = [col.strip() for col in grouping_cols.split(',')]
grouping_cols_and_types = [(col, _dict_get_quoted(all_cols_types, col))
for col in grouping_cols_list]
grouping_array_str = (
"array_to_string(array[" +
','.join("(case when " + col + " then 'True' else 'False' end)::text"
if col_type == 'boolean' else '(' + col + ')::text'
for col, col_type in grouping_cols_and_types) +
"], ',')")
grouping_cols_str = ('' if grouping_cols is None
else grouping_cols + ",")
sql_grp_key_to_grp_cols = """
CREATE TABLE {grp_key_to_grp_cols} AS
SELECT
{grouping_cols},
{grouping_array_str} AS grp_key,
(row_number() over ())::integer AS gid
FROM {training_table_name}
GROUP BY {grouping_cols}
""".format(**locals())
plpy.notice("sql_grp_key_to_grp_cols:\n" + sql_grp_key_to_grp_cols)
plpy.execute(sql_grp_key_to_grp_cols)
# find bins
num_groups = plpy.execute("""
SELECT count(*) FROM {grp_key_to_grp_cols}
""".format(**locals()))[0]['count']
plpy.notice("Analyzing data to compute split boundaries for variables")
bins = _get_bins_grps(schema_madlib, training_table_name,
cat_features, ordered_cat_features,
con_features, num_bins, dep,
boolean_cats, grouping_cols,
grouping_array_str, n_rows,
is_classification, dep_n_levels, filter_null)
cat_features = bins['cat_features']
# a table for converting cat_features to integers
cat_features_info_table = unique_string()
sql_cat_features_info = """
CREATE TEMP TABLE {cat_features_info_table} AS
SELECT
gid,
cat_n_levels,
cat_levels_in_text
FROM
(
SELECT *
FROM {schema_madlib}._gen_cat_levels_set($1, $2, $3, $4)
) subq
JOIN
{grp_key_to_grp_cols}
USING (grp_key)
""".format(**locals())
plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info)
plan_cat_features_info = plpy.prepare(
sql_cat_features_info, ['text[]', 'integer[]', 'integer', 'text[]'])
plpy.execute(plan_cat_features_info, [
bins['grp_key_cat'],
bins['cat_n'],
len(cat_features),
bins['cat_origin']])
con_splits_table = unique_string()
_create_con_splits_table(schema_madlib, con_splits_table, grouping_cols, grp_key_to_grp_cols, bins)
##################################################################
# create views and tables for training (growing) of trees
# store the prediction for all oob samples
# for classification, the prediction is of integer type here
oob_prediction_table = unique_string()
sql_create_oob_prediction_table = """
CREATE TEMP TABLE {oob_prediction_table} AS
SELECT
{id_col_name},
1 AS sample_id,
1 AS gid,
{dep} AS dep,
{dep} AS oob_prediction,
ARRAY[1.0]::float8[] AS cat_imp_score,
ARRAY[1.0]::float8[] AS con_imp_score
FROM {training_table_name}
LIMIT 0
""".format(**locals())
plpy.notice("sql_create_oob_prediction_table:\n" + sql_create_oob_prediction_table)
plpy.execute(sql_create_oob_prediction_table)
# to store poisson count defining bootstrap sample
training_pois_cnt_table = unique_string()
subsample_random_column = unique_string()
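# Like the oob prediction table above, this table is created with LIMIT 0:
# an empty table that only fixes the schema. It is truncated and refilled
# with a fresh bootstrap sample for every tree in the training loop below.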
sql_create_training_pois_cnt = """
CREATE TEMP TABLE {training_pois_cnt_table} AS
SELECT
*,
1.::double precision AS {subsample_random_column},
{schema_madlib}.poisson_random(1) AS poisson_count
FROM {training_table_name}
LIMIT 0
""".format(**locals())
plpy.notice("sql_create_training_pois_cnt:\n" + sql_create_training_pois_cnt)
plpy.execute(sql_create_training_pois_cnt)
# views dependent on current bootstrap sample
src_view = unique_string()
sql_create_src_view = """
CREATE VIEW {src_view} AS
SELECT *
FROM {training_pois_cnt_table}
WHERE poisson_count != 0
""".format(**locals())
plpy.notice("sql_create_src_view:\n" + sql_create_src_view)
plpy.execute(sql_create_src_view)
oob_view = unique_string()
sql_create_oob_view = """
CREATE VIEW {oob_view} AS
SELECT *
FROM {training_pois_cnt_table}
WHERE poisson_count = 0
""".format(**locals())
plpy.notice("sql_create_oob_view:\n" + sql_create_oob_view)
plpy.execute(sql_create_oob_view)
_create_empty_result_table(schema_madlib, output_table_name)
##################################################################
# training random forest
tree_terminated = None
for sample_id in range(1, num_trees + 1):
if 1 - sample_ratio < 1e-6:
random_sample_expr = "0.::double precision"
else:
random_sample_expr = "random()"
sql_refresh_training_pois_cnt = """
TRUNCATE TABLE {training_pois_cnt_table} CASCADE;
INSERT INTO {training_pois_cnt_table}
SELECT
*,
{schema_madlib}.poisson_random(1) AS poisson_count
FROM
(
SELECT
*,
{random_sample_expr} AS {subsample_random_column}
FROM {training_table_name}
) subq
WHERE {subsample_random_column} < {sample_ratio}
""".format(**locals())
plpy.notice("sql_refresh_training_pois_cnt:\n" + sql_refresh_training_pois_cnt)
plpy.execute(sql_refresh_training_pois_cnt)
if verbose:
tup_cnt_in_view = plpy.execute("""
SELECT
count(*) AS c,
sum(poisson_count) AS s
FROM {src_view}
""".format(**locals()))[0]
src_cnt = tup_cnt_in_view['c']
dup_cnt = tup_cnt_in_view['s']
oob_cnt = plpy.execute("""
SELECT count(*) AS c FROM {oob_view}
""".format(**locals()))[0]['c']
plpy.notice("""
src_cnt: {src_cnt},
oob_cnt: {oob_cnt},
dup_cnt: {dup_cnt}.
""".format(**locals()))
if grouping_cols is None: # non-grouping case
tree_state = _tree_train_using_bins(
schema_madlib, bins, src_view, cat_features, con_features,
boolean_cats, num_bins, 'poisson_count', dep, min_split,
min_bucket, max_tree_depth, filter_null, dep_n_levels,
is_classification, split_criterion, True,
num_random_features, max_n_surr)
tree_states = [dict(tree_state=tree_state['tree_state'],
grp_key='')]
tree_terminated = {'': tree_state['finished']}
else:
tree_states = _tree_train_grps_using_bins(
schema_madlib, bins, src_view, cat_features, con_features,
boolean_cats, num_bins, 'poisson_count', grouping_cols,
grouping_array_str, dep, min_split, min_bucket,
max_tree_depth, filter_null, dep_n_levels,
is_classification, split_criterion, True,
num_random_features, tree_terminated=tree_terminated,
max_n_surr=max_n_surr)
# If a tree for a group is terminated (not finished properly),
# then we do not need to compute other trees, and can just
# stop calculating that group further.
if tree_terminated is None:
tree_terminated = dict((item['grp_key'], item['finished'])
for item in tree_states)
else:
for item in tree_states:
if item['grp_key'] not in tree_terminated:
tree_terminated[item['grp_key']] = item['finished']
elif item['finished'] == 2:
tree_terminated[item['grp_key']] = 2
_insert_into_result_table(
schema_madlib, tree_states, output_table_name,
grp_key_to_grp_cols, sample_id)
_calculate_oob_prediction(
schema_madlib, output_table_name, cat_features_info_table,
con_splits_table, oob_prediction_table, oob_view,
sample_id, id_col_name, cat_features, con_features,
boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
num_permutations, is_classification, importance, num_bins)
###################################################################
# evaluating and summarizing the random forest
oob_error_table = unique_string()
_calculate_oob_error(schema_madlib, oob_prediction_table,
oob_error_table, id_col_name,
is_classification)
importance_table = unique_string()
sql_create_empty_imp_tbl = """
CREATE TEMP TABLE {importance_table}
(
gid integer,
cat_var_importance float8[],
con_var_importance float8[]
);
""".format(**locals())
plpy.notice("sql_create_empty_imp_tbl:\n"+sql_create_empty_imp_tbl)
plpy.execute(sql_create_empty_imp_tbl)
# we populate the importance_table only if variable importance is to be
# calculated, otherwise we use an empty table which will be used later
# for an outer join.
if importance:
_calculate_variable_importance(schema_madlib,
oob_prediction_table, is_classification,
importance_table, len(cat_features), len(con_features))
_create_group_table(schema_madlib, output_table_name,
oob_error_table, importance_table,
cat_features_info_table, grp_key_to_grp_cols,
grouping_cols, tree_terminated)
num_failed_groups = sum(1 for v in tree_terminated.values() if v != 1)
_create_summary_table(**locals())
sql_cleanup = """
DROP TABLE {training_pois_cnt_table} CASCADE;
DROP TABLE {oob_prediction_table} CASCADE;
DROP TABLE {importance_table} CASCADE;
DROP TABLE {oob_error_table} CASCADE;
DROP TABLE {cat_features_info_table} CASCADE;
DROP TABLE {con_splits_table} CASCADE;
DROP TABLE {grp_key_to_grp_cols} CASCADE;
""".format(**locals())
plpy.notice("sql_cleanup:\n" + sql_cleanup)
plpy.execute(sql_cleanup)
return None
# ------------------------------------------------------------
def forest_predict(schema_madlib, model, source, output, pred_type='response',
**kwargs):
"""
Args:
@param schema_madlib: str, Name of MADlib schema
@param model: str, Name of table containing the forest model
@param source: str, Name of table containing prediction data
@param output: str, Name of table to output the results
@param pred_type: str, The type of output required:
'response' gives the actual response values,
'prob' gives the probability of the classes in a
classification model.
For regression model, only type='response' is defined.
Returns:
None
Side effect:
Creates an output table containing the prediction for given source table
Throws:
None
"""
pred_type = 'response' if pred_type is None or pred_type == '' else pred_type
_validate_predict(model, source, output, pred_type)
model_summary = add_postfix(model, "_summary")
model_group = add_postfix(model, "_group")
# obtain the cat_features and con_features from model table
summary_elements = plpy.execute("SELECT * FROM {0}".format(model_summary))[0]
cat_features = split_quoted_delimited_str(summary_elements["cat_features"])
con_features = split_quoted_delimited_str(summary_elements["con_features"])
id_col_name = summary_elements["id_col_name"]
grouping_cols = summary_elements["grouping_cols"]
dep_varname = summary_elements["dependent_varname"]
dep_levels = summary_elements["dependent_var_levels"]
is_classification = summary_elements["is_classification"]
dep_type = summary_elements['dependent_var_type']
# pred_type='prob' is allowed only for classification
_assert(is_classification or pred_type == 'response',
"Random forest error: pred_type cannot be 'prob' for regression model.")
# find which columns are of type boolean
boolean_cats = set([key for key, value in get_cols_and_types(source)
if value == 'boolean'])
cat_features_str, con_features_str = get_feature_str(
schema_madlib, boolean_cats, cat_features, con_features,
"cat_levels_in_text", "cat_n_levels")
pred_name = ('"prob_{0}"' if pred_type == "prob" else
'"estimated_{0}"').format(dep_varname.replace('"', '').strip())
join_str = "," if grouping_cols is None else "JOIN"
using_str = "" if grouping_cols is None else "USING (" + grouping_cols + ")"
if not is_classification:
majority_pred_expression = "avg(aggregated_prediction)"
else:
majority_pred_expression = """($sql${{ {dep_levels} }}$sql$::varchar[])[
{schema_madlib}.mode(aggregated_prediction + 1)]::TEXT
""".format(**locals())
if dep_type.lower() == "boolean":
# some platforms don't have text to boolean cast. We manually check the string.
majority_pred_cast_str = ("(case {majority_pred_expression} when 'True' then "
"True else False end)::BOOLEAN as {pred_name}")
else:
majority_pred_cast_str = "{majority_pred_expression}::{dep_type} as {pred_name}"
majority_pred_cast_str = majority_pred_cast_str.format(**locals())
num_trees_grown = plpy.execute(
"SELECT count(distinct sample_id) FROM {model}"
.format(**locals()))[0]['count']
if pred_type == "response" or not is_classification:
sql_prediction = """
CREATE TABLE {output} AS
SELECT
{id_col_name},
{majority_pred_cast_str}
FROM
(
SELECT
{id_col_name},
{schema_madlib}._predict_dt_response(
tree,
{cat_features_str}::integer[],
{con_features_str}::double precision[]) AS aggregated_prediction
FROM
{source}
{join_str}
{model_group}
{using_str}
JOIN
{model}
USING (gid)
) prediction_agg
GROUP BY {id_col_name}
""".format(**locals())
else:
len_dep_levels = len(split_quoted_delimited_str(dep_levels))
normalized_majority_pred = unique_string()
score_format = ', \n'.join([
'{temp}[{j}] as "estimated_prob_{c}"'.
format(j=i+1, c=c.strip(' "'), temp=normalized_majority_pred)
for i, c in enumerate(split_quoted_delimited_str(dep_levels))])
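# e.g. for hypothetical levels 'Yes' and 'No' this expands to roughly
# <norm_pred>[1] as "estimated_prob_Yes", <norm_pred>[2] as "estimated_prob_No"
# where <norm_pred> is the unique-named array of normalized vote fractions.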
sql_prediction = """
CREATE TABLE {output} AS
SELECT
{id_col_name},
{score_format}
FROM
(
SELECT
{id_col_name},
{schema_madlib}.discrete_distribution_agg(
prediction::integer,
1.,
{len_dep_levels}
)::double precision[]
AS {normalized_majority_pred}
FROM
(
SELECT
{id_col_name},
gid,
{schema_madlib}._predict_dt_response(
tree,
{cat_features_str}::integer[],
{con_features_str}::double precision[]) as prediction
FROM
{source}
{join_str}
{model_group}
{using_str}
JOIN
{model}
USING (gid)
) class_prediction_subq
GROUP BY gid, {id_col_name}
) subq
""".format(**locals())
with MinWarning('warning'):
with EnableOptimizer(False):
# we disable optimizer (ORCA) for platforms that use it
# since ORCA doesn't provide an easy way to disable hashagg
with EnableHashagg(False):
# we disable hashagg since large number of groups could
# result in excessive memory usage.
plpy.notice("sql_prediction:\n"+sql_prediction)
plpy.execute(sql_prediction)
return None
# ------------------------------------------------------------
def get_tree_surr(schema_madlib, model_table, gid, sample_id, **kwargs):
return get_tree(schema_madlib, model_table, gid, sample_id,
dot_format=False, disp_surr=True)
def get_tree(schema_madlib, model_table, gid, sample_id,
dot_format=True, verbose=False, disp_surr=False, **kwargs):
"""Random forest tree display function"""
_validate_get_tree(model_table, gid, sample_id)
if dot_format:
disp_surr = False # surrogates cannot be displayed in dot format
bytea8 = schema_madlib + '.bytea8'
model_table_summary = add_postfix(model_table, "_summary")
model_table_group = add_postfix(model_table, "_group")
summary = plpy.execute("SELECT * FROM {model_table_summary}".
format(model_table_summary=model_table_summary))[0]
dep_levels = summary["dependent_var_levels"]
dep_levels = [''] if not dep_levels else split_quoted_delimited_str(dep_levels)
table_name = summary["source_table"]
is_regression = not summary["is_classification"]
cat_features_str = split_quoted_delimited_str(summary['cat_features'])
con_features_str = split_quoted_delimited_str(summary['con_features'])
with MinWarning('warning'):
sql_tree_result = """
SELECT
tree,
cat_levels_in_text,
cat_n_levels
FROM
{model_table}
JOIN
{model_table_group}
USING (gid)
WHERE sample_id = {sample_id}
AND gid = {gid}
""".format(**locals())
plpy.notice("sql_tree_result:\n"+sql_tree_result)
tree_result = plpy.execute(sql_tree_result)
if not tree_result:
plpy.error("Random forest error: no tree found for the given gid and sample_id")
tree = tree_result[0]
if tree['cat_levels_in_text']:
cat_levels_in_text = tree['cat_levels_in_text']
cat_n_levels = tree['cat_n_levels']
else:
cat_levels_in_text = []
cat_n_levels = []
return_str_list = []
if not disp_surr:
return_str_list.append(_get_display_header(table_name, dep_levels,
is_regression, dot_format))
else:
return_str_list.append("""
-------------------------------------
Surrogates for internal nodes
-------------------------------------
""")
with MinWarning('warning'):
if disp_surr:
# Output only surrogate information for the internal nodes of tree
sql = """SELECT {0}._display_decision_tree_surrogate(
$1, $2, $3, $4, $5) as display_tree
""".format(schema_madlib)
# execute sql to get display string
sql_plan = plpy.prepare(sql, [bytea8,
'text[]', 'text[]', 'text[]',
'int[]'])
tree_display = plpy.execute(
sql_plan, [tree['tree'], cat_features_str, con_features_str,
cat_levels_in_text, cat_n_levels])[0]
else:
# Output the splits in each node of the tree
if dot_format:
sql_display = """
SELECT {0}._display_decision_tree(
$1, $2, $3, $4, $5, $6, '{1}', {2}
) as display_tree
""".format(schema_madlib, "", verbose)
else:
sql_display = """
SELECT {0}._display_text_decision_tree(
$1, $2, $3, $4, $5, $6
) as display_tree
""".format(schema_madlib)
plpy.notice("sql_display:\n"+sql_display)
plan_display = plpy.prepare(sql_display, [bytea8,
'text[]', 'text[]', 'text[]',
'int[]', 'text[]'])
tree_display = plpy.execute(
plan_display, [tree['tree'], cat_features_str, con_features_str,
cat_levels_in_text, cat_n_levels,
dep_levels])[0]
return_str_list.append(tree_display["display_tree"])
if dot_format:
return_str_list.append("} //---end of digraph--------- ")
return ("\n".join(return_str_list))
# ------------------------------------------------------------
def _calculate_oob_prediction(
schema_madlib, model_table, cat_features_info_table, con_splits_table,
oob_prediction_table, oob_view, sample_id, id_col_name, cat_features,
con_features, boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
num_permutations, is_classification, importance, num_bins):
"""Calculate predication for out-of-bag sample"""
cat_features_str, con_features_str = get_feature_str(
schema_madlib, boolean_cats, cat_features, con_features,
"cat_levels_in_text", "cat_n_levels")
join_str = "," if grouping_cols is None else "JOIN"
using_str = "" if grouping_cols is None else "USING (" + grouping_cols + ")"
oob_var_dist_view = unique_string()
if importance:
sql_create_oob_var_dist_view = """
CREATE VIEW {oob_var_dist_view} AS
SELECT
gid,
{schema_madlib}.vectorized_distribution_agg(
{schema_madlib}.array_scalar_add(
{cat_features_str}::integer[],
1 -- -1 shift to 0 for nulls
),
{schema_madlib}.array_scalar_add(
cat_n_levels,
1 -- -1 shift to 0 for nulls
)
) AS cat_feature_distributions,
{schema_madlib}.vectorized_distribution_agg(
{schema_madlib}.array_scalar_add(
{schema_madlib}._get_bin_indices_by_values(
{con_features_str}::double precision[],
con_splits
), -- bin_indices, -1 for NaN
1 -- -1 shift to 0 for nulls
),
{schema_madlib}.array_fill(
{schema_madlib}.array_of_float({n_con})::integer[],
({num_bins}+1)::integer
)
-- level of any continuous feature == num_bins
) AS con_index_distributions
FROM
{oob_view}
{join_str}
{grp_key_to_grp_cols}
{using_str}
JOIN
{cat_features_info_table}
USING (gid)
JOIN
{con_splits_table}
USING (gid)
GROUP BY gid
""".format(n_con=len(con_features), **locals())
else:
sql_create_oob_var_dist_view = """
CREATE VIEW {oob_var_dist_view} AS
SELECT
gid,
NULL::float8[] AS cat_feature_distributions,
NULL::float8[] AS con_index_distributions
FROM {cat_features_info_table}
""".format(**locals())
plpy.notice("sql_create_oob_var_dist_view : " + str(sql_create_oob_var_dist_view))
plpy.execute(sql_create_oob_var_dist_view)
sql_oob_predict = """
INSERT INTO {oob_prediction_table}
SELECT
{id_col_name},
sample_id,
gid,
{dep} AS dep,
{schema_madlib}._predict_dt_response(
tree,
{cat_features_str}::integer[],
{con_features_str}::double precision[]
) AS oob_prediction,
{schema_madlib}._rf_cat_imp_score(
tree,
{cat_features_str}::integer[],
{con_features_str}::double precision[],
cat_info.cat_n_levels,
{num_permutations},
{dep},
{is_classification},
cat_feature_distributions -- if distribution is NULL, returns NULL
) AS cat_imp_score,
{schema_madlib}._rf_con_imp_score(
tree,
{cat_features_str}::integer[],
{con_features_str}::double precision[],
con_info.con_splits,
{num_permutations},
{dep},
{is_classification},
con_index_distributions -- if distribution is NULL, returns NULL
) AS con_imp_score
FROM
{oob_view}
{join_str}
{grp_key_to_grp_cols}
{using_str}
JOIN
(
SELECT *
FROM {model_table}
WHERE sample_id = {sample_id}
) m
USING (gid)
JOIN
{cat_features_info_table} cat_info
USING (gid)
JOIN
{con_splits_table} con_info
USING (gid)
LEFT OUTER JOIN -- empty if variable importance is disabled
{oob_var_dist_view}
USING (gid)
""".format(**locals())
plpy.notice("sql_oob_predict : " + str(sql_oob_predict))
plpy.execute(sql_oob_predict)
# -------------------------------------------------------------------------
def _create_con_splits_table(schema_madlib, con_splits_table, grouping_cols,
grp_key_to_grp_cols, bins):
bytea8 = schema_madlib + '.bytea8'
bytea8arr = schema_madlib + '.bytea8[]'
if grouping_cols is None:
sql_create_con_splits_table = """
CREATE TEMP TABLE {con_splits_table} AS
SELECT
1 AS gid,
$1 AS con_splits
""".format(con_splits_table=con_splits_table)
plpy.notice("sql_create_con_splits_table:\n"+sql_create_con_splits_table)
sql_create_con_splits_plan = plpy.prepare(sql_create_con_splits_table,
[bytea8])
plpy.execute(sql_create_con_splits_plan, [bins['con']])
else:
sql_create_con_splits_table = """
CREATE TABLE {con_splits_table} AS
SELECT
gid,
con_splits
FROM
{grp_key_to_grp_cols}
JOIN
(
SELECT
unnest($1) as grp_key,
unnest($2) as con_splits
) subq
USING (grp_key)
""".format(**locals())
plpy.notice("sql_create_con_splits_table:\n"+sql_create_con_splits_table)
sql_create_con_splits_plan = plpy.prepare(sql_create_con_splits_table,
['text[]', bytea8arr])
plpy.execute(sql_create_con_splits_plan,
[bins['grp_key_con'], bins['con']])
# ------------------------------------------------------------------------------
def _calculate_variable_importance(schema_madlib, oob_prediction_table,
is_classification, importance_table, n_cat, n_con):
if not is_classification:
score_expression = "-((oob_prediction - dep)^2)"
else:
score_expression = """
CASE WHEN dep = oob_prediction::integer
THEN 1.
ELSE 0.
END"""
sample_score_view = unique_string()
sql_create_sample_score_view = """
CREATE VIEW {sample_score_view} AS
SELECT
sample_id,
gid,
count(*) as size,
sum({score_expression}) as score,
{schema_madlib}.sum(cat_imp_score::FLOAT8[]) AS cat_imp_score,
{schema_madlib}.sum(con_imp_score::FLOAT8[]) AS con_imp_score
FROM
{oob_prediction_table}
GROUP BY sample_id, gid
""".format(**locals())
plpy.notice("sql_create_sample_score_view:\n" + sql_create_sample_score_view)
plpy.execute(sql_create_sample_score_view)
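# The imp_score arrays hold summed scores obtained after permuting each
# feature. array_scalar_mult(array_scalar_add(imp, -score), -1/size)
# works out to (score - imp) / size, i.e. the per-row mean decrease in
# score caused by permuting that feature; averaging over all bootstrap
# samples gives the reported variable importance.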
sql_create_importance_table = """
INSERT INTO {importance_table}
SELECT
gid,
{schema_madlib}.array_avg(
{schema_madlib}.array_scalar_mult(
{schema_madlib}.array_scalar_add(
cat_imp_score,
-score::float8
),
(-1. / size)::float8
),
FALSE -- not use absolute values
),
{schema_madlib}.array_avg(
{schema_madlib}.array_scalar_mult(
{schema_madlib}.array_scalar_add(
con_imp_score,
-score::float8
),
(-1. / size)::float8
),
FALSE -- not use absolute values
)
FROM
{sample_score_view}
GROUP BY gid
""".format(**locals())
plpy.notice("sql_create_importance_table:\n" + sql_create_importance_table)
plpy.execute(sql_create_importance_table)
# -------------------------------------------------------------------------
def _calculate_oob_error(schema_madlib, oob_prediction_table, oob_error_table,
id_col_name, is_classification):
"""Calculate out-of-bag error for oob samples"""
if not is_classification:
residual_expression = "(dep - forest_prediction)^2"
forest_prediction_agg = 'avg'
else:
residual_expression = """
CASE WHEN dep = forest_prediction::integer
THEN 0.
ELSE 1.
END"""
forest_prediction_agg = "{schema_madlib}.mode".format(**locals())
sql_compute_oob_error = """
CREATE TABLE {oob_error_table} AS
SELECT
gid,
avg({residual_expression}) AS oob_error
FROM
(
SELECT
gid,
dep,
{forest_prediction_agg}(oob_prediction) AS forest_prediction
FROM
{oob_prediction_table}
GROUP BY gid, {id_col_name}, dep
) prediction_subq
GROUP BY gid
""".format(**locals())
plpy.notice("sql_compute_oob_error : " + str(sql_compute_oob_error))
plpy.execute(sql_compute_oob_error)
# -------------------------------------------------------------------------
def _create_summary_table(**kwargs):
kwargs['features'] = ','.join(kwargs['cat_features'] + kwargs['con_features'])
if kwargs['dep_list']:
kwargs['dep_list_str'] = (
"$dep_list$" +
','.join('"{0}"'.format(str(dep)) for dep in kwargs['dep_list']) +
"$dep_list$")
else:
kwargs['dep_list_str'] = "NULL"
kwargs['indep_type'] = ', '.join(_dict_get_quoted(kwargs['all_cols_types'], col)
for col in kwargs['cat_features'] + kwargs['con_features'])
kwargs['dep_type'] = _get_dep_type(kwargs['training_table_name'],
kwargs['dependent_variable'])
kwargs['cat_features_str'] = ','.join(kwargs['cat_features'])
kwargs['con_features_str'] = ','.join(kwargs['con_features'])
if kwargs['grouping_cols']:
kwargs['grouping_cols_str'] = "'{grouping_cols}'".format(**kwargs)
else:
kwargs['grouping_cols_str'] = 'NULL'
kwargs['n_rows_skipped'] = kwargs['n_all_rows'] - kwargs['n_rows']
kwargs['output_table_summary'] = add_postfix(kwargs['output_table_name'], "_summary")
sql_create_summary_table = """
CREATE TABLE {output_table_summary} AS
SELECT
'forest_train'::text AS method,
'{is_classification}'::boolean AS is_classification,
'{training_table_name}'::text AS source_table,
'{output_table_name}'::text AS model_table,
'{id_col_name}'::text AS id_col_name,
'{dependent_variable}'::text AS dependent_varname,
'{features}'::text AS independent_varnames,
'{cat_features_str}'::text AS cat_features,
'{con_features_str}'::text AS con_features,
{grouping_cols_str}::text AS grouping_cols,
{num_trees}::integer AS num_trees,
{num_random_features}::integer AS num_random_features,
{max_tree_depth}::integer AS max_tree_depth,
{min_split}::integer AS min_split,
{min_bucket}::integer AS min_bucket,
{num_bins}::integer AS num_splits,
{verbose}::boolean AS verbose,
{importance}::boolean AS importance,
{num_permutations}::integer AS num_permutations,
{num_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{n_rows}::integer AS total_rows_processed,
{n_rows_skipped}::integer AS total_rows_skipped,
{dep_list_str}::text AS dependent_var_levels,
'{dep_type}'::text AS dependent_var_type,
'{indep_type}'::text AS independent_var_types
""".format(**kwargs)
plpy.notice("sql_create_summary_table:\n" + sql_create_summary_table)
plpy.execute(sql_create_summary_table)
# ------------------------------------------------------------
def _create_group_table(
schema_madlib, output_table_name, oob_error_table,
importance_table, cat_features_info_table, grp_key_to_grp_cols,
grouping_cols, tree_terminated):
""" Ceate the group table for random forest"""
grouping_cols_str = ('' if grouping_cols is None
else grouping_cols + ",")
group_table_name = add_postfix(output_table_name, "_group")
sql_create_group_table = """
CREATE TABLE {group_table_name} AS
SELECT
gid,
{grouping_cols_str}
grp_finished as success,
cat_n_levels,
cat_levels_in_text,
oob_error,
cat_var_importance,
con_var_importance
FROM
{oob_error_table}
JOIN
{grp_key_to_grp_cols}
USING (gid)
JOIN (
SELECT
unnest($1) as grp_key,
unnest($2) as grp_finished
) tree_terminated
USING (grp_key)
JOIN
{cat_features_info_table}
USING (gid)
LEFT OUTER JOIN
{importance_table}
USING (gid)
""".format(**locals())
plpy.notice("sql_create_group_table:\n" + sql_create_group_table)
plan_create_group_table = plpy.prepare(sql_create_group_table,
['text[]', 'boolean[]'])
plpy.execute(plan_create_group_table,
[list(tree_terminated.keys()),
[v == 1 for v in tree_terminated.values()]])
# -------------------------------------------------------------------------
def _create_empty_result_table(schema_madlib, output_table_name):
"""Create the result table for all trees in the forest"""
sql_create_empty_result_table = """
CREATE TABLE {output_table_name} (
gid integer,
sample_id integer,
tree {schema_madlib}.bytea8);
""".format(**locals())
plpy.notice("sql_create_empty_result_table:\n" + sql_create_empty_result_table)
plpy.execute(sql_create_empty_result_table)
# ------------------------------------------------------------
def _insert_into_result_table(schema_madlib, tree_states, output_table_name,
grp_key_to_grp_cols, sample_id):
"""Insert one tree to result table"""
sql = """
INSERT INTO {output_table_name}
SELECT
gid,
{sample_id} AS sample_id,
tree
FROM
(
SELECT
unnest($1) AS grp_key,
unnest($2) AS tree
) grp_key_to_tree
JOIN
{grp_key_to_grp_cols}
USING (grp_key)
""".format(**locals())
sql_plan = plpy.prepare(sql, ['text[]', '{0}.bytea8[]'.format(schema_madlib)])
plpy.execute(sql_plan, [
[tree_state['grp_key'] for tree_state in tree_states],
[tree_state['tree_state'] for tree_state in tree_states]])
# ------------------------------------------------------------
def _forest_validate_args(
training_table_name, output_table_name, id_col_name,
list_of_features, dependent_variable, list_of_features_to_exclude,
grouping_cols, num_trees, num_random_features, n_perm,
max_tree_depth, min_split, min_bucket, num_bins, sample_ratio):
""" Validate the arguments for the random forest training function"""
input_tbl_valid(training_table_name, 'Random forest')
cols_in_tbl_valid(training_table_name, [id_col_name], 'Random forest')
output_tbl_valid(output_table_name, 'Random forest')
output_tbl_valid(add_postfix(output_table_name, "_group"), 'Random forest')
output_tbl_valid(add_postfix(output_table_name, "_summary"), 'Random forest')
_assert(not (list_of_features is None or list_of_features.strip().lower() == ''),
"Random forest error: Features to include is empty.")
if list_of_features.strip() != '*':
_assert(is_var_valid(training_table_name, list_of_features),
"Random forest error: Invalid feature list ({0})".
format(list_of_features))
_assert(not (dependent_variable is None or dependent_variable.strip().lower() == ''),
"Random forest error: Dependent variable is empty.")
_assert(is_var_valid(training_table_name, dependent_variable),
"Random forest error: Invalid dependent variable ({0}).".
format(dependent_variable))
if grouping_cols is not None and grouping_cols.strip() != '':
_assert(is_var_valid(training_table_name, grouping_cols),
"Random forest error: Invalid grouping column argument.")
_assert(num_trees > 0, "Random forest error: num_trees must be positive.")
_assert(n_perm > 0, "Random forest error: num_permutations must be positive.")
if num_random_features is not None:
_assert(num_random_features > 0,
"Random forest error: num_random_features must be positive.")
_assert(max_tree_depth >= 0 and max_tree_depth <= 15,
"Random forest error: max_tree_depth must be non-negative and less than 16.")
_assert(min_split > 0, "Random forest error: min_split must be positive.")
_assert(min_bucket > 0, "Random forest error: min_bucket must be positive.")
_assert(num_bins > 1, "Random forest error: number of bins must be at least 2.")
_assert(sample_ratio > 0 and sample_ratio <= 1,
"Random forest error: sample_ratio must be in (0, 1].")
# ------------------------------------------------------------
def _validate_predict(model, source, output, pred_type):
"""Validations for input arguments"""
input_tbl_valid(model, 'Random forest')
cols_in_tbl_valid(model, ['gid', 'sample_id', 'tree'], 'Random forest')
input_tbl_valid(add_postfix(model, "_group"), 'Random forest')
cols_in_tbl_valid(add_postfix(model, "_group"),
['gid', 'cat_n_levels', 'cat_levels_in_text'],
'Random forest')
input_tbl_valid(add_postfix(model, "_summary"), 'Random forest')
cols_in_tbl_valid(add_postfix(model, "_summary"),
["grouping_cols", "id_col_name", "dependent_varname",
"cat_features", "con_features", "is_classification"],
'Random forest')
input_tbl_valid(source, 'Random forest')
output_tbl_valid(output, 'Random forest')
_assert(pred_type in ('response', 'prob'),
"Random forest error: pred_type should be 'response' or 'prob'")
# ------------------------------------------------------------
def _validate_get_tree(model, gid, sample_id):
"""Validations for input arguments"""
input_tbl_valid(model, 'Random forest')
cols_in_tbl_valid(model, ['gid', 'sample_id', 'tree'], 'Random forest')
# ------------------------------------------------------------
def forest_predict_help_message(schema_madlib, message, **kwargs):
if not message:
help_string = """
------------------------------------------------------------
SUMMARY
------------------------------------------------------------
Functionality: Random Forest Prediction
Random forests build an ensemble of decision trees to predict
the value of a target variable based on several input variables.
This is the function to make predictions using the model trained
by the function 'forest_train'.
For more details on the function usage:
SELECT {schema_madlib}.forest_predict('usage');
For an example on using this function:
SELECT {schema_madlib}.forest_predict('example');
"""
elif message.lower().strip() in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------
USAGE
------------------------------------------------------------
SELECT {schema_madlib}.forest_predict(
'forest_model', -- Model table name (output of forest_train)
'new_data_table', -- Source data table
'output_table', -- The name of the table storing the predictions
'type' -- Type of prediction output, 'response' or 'prob'
);
Note: The 'new_data_table' should have the same 'id_col_name' column as used
in the training function. This is used to correlate each row of the prediction
data with its prediction in the output table.
------------------------------------------------------------
OUTPUT
------------------------------------------------------------
The output table ('output_table' above) has the '<id_col_name>' column giving
the 'id' for each prediction, plus the prediction columns for the response
variable (also called the dependent variable).
If prediction type = 'response', then the table has a single column with the
prediction value of the response. The type of this column depends on the type
of the response variable used during training. For regression the response
value is the average prediction across all trees; for classification it is
the majority vote.
If prediction type = 'prob', then the table has multiple columns, one for each
possible value of the response variable. The columns are labeled
'estimated_prob_<dep value>', where <dep value> represents each possible value
of the response. This is available only for classification models; each value
is the fraction of trees voting for that category.
"""
elif message.lower().strip() in ['example', 'examples']:
help_string = """
------------------------------------------------------------
EXAMPLE
------------------------------------------------------------
-- Assuming the example of forest_train has been run
SELECT {schema_madlib}.forest_predict(
'train_output',
'dt_golf',
'forest_predict_out',
'response'
);
"""
else:
help_string = "No such option. Use {schema_madlib}.forest_predict('usage')"
return help_string.format(schema_madlib=schema_madlib)