/* ----------------------------------------------------------------------- *//**
 *
 * @file decision_tree.cpp
 *
 * @brief Decision Tree models for regression and classification
 *
 *//* ----------------------------------------------------------------------- */
#include <iostream>
#include <sstream>
#include <vector>
#include <string>
#include <list>
#include <iterator>
#include <dbconnector/dbconnector.hpp>
#include "DT_proto.hpp"
#include "DT_impl.hpp"
#include "ConSplits.hpp"
#include <math.h>   /* fabs */
#include <iomanip>  /* std::setprecision, used in print_debug_list */
#include "decision_tree.hpp"
namespace madlib {
// Use Eigen
using namespace dbal::eigen_integration;
namespace modules {
namespace recursive_partitioning {
enum { NOT_FINISHED=0, FINISHED, TERMINATED };
// ------------------------------------------------------------
// types
// ------------------------------------------------------------
typedef DecisionTree<MutableRootContainer> MutableTree;
typedef DecisionTree<RootContainer> Tree;
// Transition State for collecting statistics
typedef TreeAccumulator<RootContainer, Tree> LevelState;
typedef TreeAccumulator<MutableRootContainer, Tree> MutableLevelState;
// ------------------------------------------------------------
// functions
// ------------------------------------------------------------
AnyType
initialize_decision_tree::run(AnyType & args){
DecisionTree<MutableRootContainer> dt = DecisionTree<MutableRootContainer>();
bool is_regression_tree = args[0].getAs<bool>();
std::string impurity_func_str = args[1].getAs<std::string>();
uint16_t n_y_labels = args[2].getAs<uint16_t>();
uint16_t max_n_surr = args[3].getAs<uint16_t>();
if (is_regression_tree)
n_y_labels = REGRESS_N_STATS;
dt.rebind(1u, n_y_labels, max_n_surr, is_regression_tree);
dt.feature_indices(0) = dt.IN_PROCESS_LEAF;
dt.feature_thresholds(0) = 0;
dt.is_categorical(0) = 0;
if (max_n_surr > 0){
dt.surr_indices.setConstant(dt.SURR_NON_EXISTING);
dt.surr_thresholds.setConstant(0);
dt.surr_status.setConstant(0);
}
dt.predictions.row(0).setConstant(0);
dt.is_regression = is_regression_tree;
if (dt.is_regression){
dt.impurity_type = dt.MSE; // only MSE defined for regression
} else {
if ((impurity_func_str.compare("misclassification") == 0) ||
(impurity_func_str.compare("misclass") == 0))
dt.impurity_type = dt.MISCLASS;
else if ((impurity_func_str.compare("entropy") == 0) ||
(impurity_func_str.compare("cross-entropy") == 0))
dt.impurity_type = dt.ENTROPY;
else
dt.impurity_type = dt.GINI; // default impurity for classification
}
return dt.storage();
} // initialize_decision_tree
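// For reference, these enum values correspond to the standard impurity
// definitions: for class proportions p_i at a node,
//   GINI     = 1 - sum_i p_i^2
//   ENTROPY  = - sum_i p_i * log(p_i)
//   MISCLASS = 1 - max_i p_i
// while regression uses MSE, the weighted variance of the response.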
// ------------------------------------------------------------
////////////////////////////////////////////////////////////////
// Functions to capture leaf stats for picking primary splits //
////////////////////////////////////////////////////////////////
AnyType
compute_leaf_stats_transition::run(AnyType & args){
MutableLevelState state = args[0].getAs<MutableByteString>();
LevelState::tree_type dt = args[1].getAs<ByteString>();
// NOTE: argument positions here must stay in sync with the calling SQL function
if (state.terminated || args[4].isNull()) {
return args[0];
}
double response = args[4].getAs<double>();
double weight = args[5].getAs<double>();
if (weight < 0)
throw std::runtime_error("Negative weights present in the data");
NativeIntegerVector cat_features;
NativeColumnVector con_features;
try {
if (args[2].isNull()){
cat_features.rebind(this->allocateArray<int>(0));
}
else {
NativeIntegerVector xx_cat = args[2].getAs<NativeIntegerVector>();
cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
}
if (args[3].isNull()){
con_features.rebind(this->allocateArray<double>(0));
}
else {
NativeColumnVector xx_con = args[3].getAs<NativeColumnVector>();
con_features.rebind(xx_con.memoryHandle(), xx_con.size());
}
} catch (const ArrayWithNullException &e) {
return args[0];
}
// cat_levels.size = n_cat_features
NativeIntegerVector cat_levels;
if (args[6].isNull()){
cat_levels.rebind(this->allocateArray<int>(0));
}
else {
MutableNativeIntegerVector n_levels_per_cat =
args[6].getAs<MutableNativeIntegerVector>();
for (Index i = 0; i < n_levels_per_cat.size(); i++){
n_levels_per_cat[i] -= 1;
// Ignore the last level: a split like 'var <= last level'
// would move all rows to one side, so such a split is
// always ignored when selecting the best split.
}
cat_levels.rebind(n_levels_per_cat.memoryHandle(), n_levels_per_cat.size());
}
// con_splits size = num_con_features x num_bins
// When num_con_features = 0, the input will be an empty string that is read
// as a ByteString
ConSplitsResult<RootContainer> splits_results = args[7].getAs<ByteString>();
// n_response_labels is the number of distinct values the dependent variable takes
uint16_t n_response_labels = args[8].getAs<uint16_t>();
if (!dt.is_regression && n_response_labels <= 1){
// for classification, response should have at least two distinct values
throw std::runtime_error("Invalid response variable for a classification"
"tree. Should have "
"more than one distinct value");
}
if (state.empty()){
// To initialize the accumulator, first find which of the leaf nodes
// in the current tree are actually reachable.
// The lookup vector maps a leaf node's index in a (fictional) complete
// tree to its index in the actual tree.
ColumnVector leaf_feature_indices =
dt.feature_indices.tail(dt.feature_indices.size()/2 + 1).cast<double>();
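// (Layout note: the tree is a complete binary tree flattened into an
// array; with 2^d - 1 slots for depth d, the deepest level is the
// trailing size/2 + 1 = 2^(d-1) entries extracted above.)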
IntegerVector leaf_node_lookup(leaf_feature_indices.size());
uint32_t n_leaves_not_finished = 0;
for (Index i=0; i < leaf_feature_indices.size(); i++){
if ((leaf_feature_indices(i) != dt.NODE_NON_EXISTING) &&
(leaf_feature_indices(i) != dt.FINISHED_LEAF)){
leaf_node_lookup(i) = static_cast<int>(n_leaves_not_finished);
n_leaves_not_finished++;
}
else{
leaf_node_lookup(i) = -1;
}
}
// see DT_proto.hpp for an explanation of stats_per_split
uint16_t stats_per_split = dt.is_regression ?
REGRESS_N_STATS : static_cast<uint16_t>(n_response_labels + 1);
const bool weights_as_rows = args[9].getAs<bool>();
state.rebind(static_cast<uint16_t>(splits_results.con_splits.cols()),
static_cast<uint16_t>(cat_features.size()),
static_cast<uint16_t>(con_features.size()),
static_cast<uint32_t>(cat_levels.sum()),
static_cast<uint16_t>(dt.tree_depth),
stats_per_split,
weights_as_rows,
static_cast<uint32_t>(n_leaves_not_finished)
);
for (Index i=0; i < state.stats_lookup.size(); i++)
state.stats_lookup(i) = leaf_node_lookup(i);
// compute cumulative sum of the levels of the categorical variables
int current_sum = 0;
for (Index i=0; i < state.n_cat_features; ++i){
// Assuming that the levels of each categorical variable are ordered,
// create splits of the form 'A <= t', where A has N levels
// and t in [0, N-2]. Such a split places all levels <= t on the
// true node and the rest on the false node. We check up to N-2
// rather than N-1 since at least one level must go to the false node.
// A variable with just one level is kept to preserve alignment,
// even though it will never be used as a split feature.
current_sum += cat_levels(i);
state.cat_levels_cumsum(i) = current_sum;
}
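// Illustrative example: original levels [3, 2, 4] become cat_levels
// [2, 1, 3] after dropping the last level, so cat_levels_cumsum is
// [2, 3, 6]; the cumulative sums give each categorical feature's
// offset into the per-split statistics of the accumulator.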
}
state << MutableLevelState::tuple_type(dt, cat_features, con_features,
response, weight,
cat_levels,
splits_results.con_splits);
return state.storage();
} // transition function
// ------------------------------------------------------------
AnyType
compute_leaf_stats_merge::run(AnyType & args){
MutableLevelState stateLeft = args[0].getAs<MutableByteString>();
LevelState stateRight = args[1].getAs<ByteString>();
if (stateLeft.empty()) {
return stateRight.storage();
}
if (!stateRight.empty()) {
stateLeft << stateRight;
}
return stateLeft.storage();
} // merge function
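// The transition/merge pair above follows MADlib's distributed
// aggregate convention: each segment accumulates a partial LevelState,
// and merge combines two partial states, so statistics collection can
// run in parallel across the cluster.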
AnyType
dt_apply::run(AnyType & args){
MutableTree dt = args[0].getAs<MutableByteString>();
LevelState curr_level = args[1].getAs<ByteString>();
// 0 = running, 1 = finished training, 2 = terminated prematurely
uint16_t return_code;
if (!curr_level.terminated){
ConSplitsResult<RootContainer> con_splits_results = args[2].getAs<ByteString>();
uint16_t min_split = args[3].getAs<uint16_t>();
uint16_t min_bucket = args[4].getAs<uint16_t>();
uint16_t max_depth = args[5].getAs<uint16_t>();
bool subsample = args[6].getAs<bool>();
int num_random_features = args[7].getAs<int>();
bool finished = false;
if (!subsample) {
finished = dt.expand(curr_level, con_splits_results.con_splits,
min_split, min_bucket, max_depth);
} else {
finished = dt.expand_by_sampling(curr_level, con_splits_results.con_splits,
min_split, min_bucket, max_depth,
num_random_features);
}
return_code = finished ? FINISHED : NOT_FINISHED;
}
else{
return_code = TERMINATED; // indicates termination due to error
}
AnyType output_tuple;
output_tuple << dt.storage()
<< return_code
<< static_cast<uint16_t>(dt.tree_depth - 1);
return output_tuple;
} // apply function
// -------------------------------------------------------------------------
///////////////////////////////////////////////
// Functions to capture surrogate statistics //
///////////////////////////////////////////////
AnyType
compute_surr_stats_transition::run(AnyType & args){
MutableLevelState state = args[0].getAs<MutableByteString>();
LevelState::tree_type dt = args[1].getAs<ByteString>();
// NOTE: argument positions here must stay in sync with the calling SQL function
if (state.terminated) {
return args[0];
}
NativeIntegerVector cat_features;
if (args[2].isNull()){
cat_features.rebind(this->allocateArray<int>(0));
} else {
NativeIntegerVector xx_cat = args[2].getAs<NativeIntegerVector>();
cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
}
NativeColumnVector con_features;
if (args[3].isNull()){
con_features.rebind(this->allocateArray<double>(0));
} else {
NativeColumnVector xx_con = args[3].getAs<NativeColumnVector>();
con_features.rebind(xx_con.memoryHandle(), xx_con.size());
}
// cat_levels size = n_cat_features
NativeIntegerVector cat_levels;
if (args[4].isNull()){
cat_levels.rebind(this->allocateArray<int>(0));
}
else {
MutableNativeIntegerVector xx_cat = args[4].getAs<MutableNativeIntegerVector>();
for (Index i = 0; i < xx_cat.size(); i++)
xx_cat[i] -= 1; // ignore the last level (same reasoning as in compute_leaf_stats_transition)
cat_levels.rebind(xx_cat.memoryHandle(), xx_cat.size());
}
// con_splits size = n_con_features x n_bins
ConSplitsResult<RootContainer> splits_results = args[5].getAs<MutableByteString>();
// tree_depth = 1 implies the tree is a single leaf node.
// We compute surrogates only for internal nodes; hence the root must
// be an internal node, i.e. tree_depth must be greater than 1.
if (dt.tree_depth > 1){
if (state.empty()){
// To initialize the accumulator, first find which nodes in the
// last level of internal nodes are actually reachable.
ColumnVector final_internal_feature_indices =
dt.feature_indices.segment(dt.feature_indices.size()/4,
dt.feature_indices.size()/4 + 1).cast<double>();
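// (In the flat complete-tree layout, the last level of internal nodes
// starts at index size/4 and spans size/4 + 1 entries, e.g. indices
// 3..6 for a depth-4 tree stored in 15 slots.)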
IntegerVector index_lookup(final_internal_feature_indices.size());
uint32_t n_internal_nodes_reachable = 0;
for (Index i=0; i < final_internal_feature_indices.size(); i++){
if (final_internal_feature_indices(i) >= 0){
index_lookup(i) = static_cast<int>(n_internal_nodes_reachable);
n_internal_nodes_reachable++;
}
else{
index_lookup(i) = -1;
}
}
// 1. We need to compute stats for the parent of each leaf,
// hence tree_depth is decremented by 1.
// 2. We store 2 values for each surrogate split:
// the 1st value is for the <= split, the 2nd for the > split
// (hence stats_per_split = 2)
state.rebind(static_cast<uint16_t>(splits_results.con_splits.cols()),
static_cast<uint16_t>(cat_features.size()),
static_cast<uint16_t>(con_features.size()),
static_cast<uint32_t>(cat_levels.sum()),
static_cast<uint16_t>(dt.tree_depth - 1),
2,
false, // dummy, only used in compute_leaf_stats_transition
n_internal_nodes_reachable
);
for (Index i = 0; i < state.stats_lookup.size(); i++)
state.stats_lookup(i) = index_lookup(i);
// compute cumulative sum of the levels of the categorical variables
int current_sum = 0;
for (Index i=0; i < state.n_cat_features; i++){
current_sum += cat_levels(i);
state.cat_levels_cumsum(i) = current_sum;
}
}
const int dup_count = args[6].getAs<int>();
state << MutableLevelState::surr_tuple_type(
dt, cat_features, con_features, cat_levels, splits_results.con_splits, dup_count);
}
return state.storage();
}
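// For each reachable internal node, the accumulator gathers the two
// agreement counts per candidate surrogate split described above (one
// per branch direction); dt_surr_apply below then selects the
// best-agreeing candidates via pickSurrogates().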
AnyType
dt_surr_apply::run(AnyType & args){
MutableTree dt = args[0].getAs<MutableByteString>();
LevelState curr_level_surr = args[1].getAs<ByteString>();
if (!curr_level_surr.terminated && dt.max_n_surr > 0){
ConSplitsResult<RootContainer> con_splits_results =
args[2].getAs<ByteString>();
dt.pickSurrogates(curr_level_surr, con_splits_results.con_splits);
}
return dt.storage();
} // apply function
// -------------------------------------------------------------------------
/*
@brief Return the class probabilities as the prediction
*/
AnyType
predict_dt_prob::run(AnyType &args){
if (args[0].isNull()){
return Null();
}
Tree dt = args[0].getAs<ByteString>();
NativeIntegerVector cat_features;
NativeColumnVector con_features;
try {
if (args[1].isNull()){
cat_features.rebind(this->allocateArray<int>(0));
}
else {
NativeIntegerVector xx_cat = args[1].getAs<NativeIntegerVector>();
cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
}
if (args[2].isNull()){
con_features.rebind(this->allocateArray<double>(0));
}
else {
NativeColumnVector xx_con = args[2].getAs<NativeColumnVector>();
con_features.rebind(xx_con.memoryHandle(), xx_con.size());
}
} catch (const ArrayWithNullException &e) {
return Null();
}
ColumnVector prediction = dt.predict(cat_features, con_features);
return prediction;
}
// -------------------------------------------------------------------------
/*
@brief Return the regression prediction, or the class corresponding to the maximum probability
*/
AnyType
predict_dt_response::run(AnyType &args){
if (args[0].isNull()){
return Null();
}
Tree dt = args[0].getAs<ByteString>();
NativeIntegerVector cat_features;
NativeColumnVector con_features;
try {
if (args[1].isNull()){
cat_features.rebind(this->allocateArray<int>(0));
}
else {
NativeIntegerVector xx_cat = args[1].getAs<NativeIntegerVector>();
cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
}
if (args[2].isNull()){
con_features.rebind(this->allocateArray<double>(0));
}
else {
NativeColumnVector xx_con = args[2].getAs<NativeColumnVector>();
con_features.rebind(xx_con.memoryHandle(), xx_con.size());
}
} catch (const ArrayWithNullException &e) {
// reach here only if surrogates are not used
return Null();
}
return dt.predict_response(cat_features, con_features);
}
AnyType
display_decision_tree::run(AnyType &args) {
Tree dt = args[0].getAs<ByteString>();
ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
ArrayHandle<text*> con_feature_names = args[2].getAs<ArrayHandle<text*> >();
ArrayHandle<text*> cat_levels_text = args[3].getAs<ArrayHandle<text*> >();
ArrayHandle<int> cat_n_levels = args[4].getAs<ArrayHandle<int> >();
ArrayHandle<text*> dependent_var_levels = args[5].getAs<ArrayHandle<text*> >();
std::string id_prefix = args[6].getAs<std::string>();
bool verbose = args[7].getAs<bool>();
string tree_str = dt.display(cat_feature_names, con_feature_names,
cat_levels_text, cat_n_levels,
dependent_var_levels, id_prefix, verbose);
return tree_str;
}
AnyType
display_decision_tree_surrogate::run(AnyType &args) {
Tree dt = args[0].getAs<ByteString>();
ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
ArrayHandle<text*> con_feature_names = args[2].getAs<ArrayHandle<text*> >();
ArrayHandle<text*> cat_levels_text = args[3].getAs<ArrayHandle<text*> >();
ArrayHandle<int> cat_n_levels = args[4].getAs<ArrayHandle<int> >();
return dt.surr_display(cat_feature_names, con_feature_names,
cat_levels_text, cat_n_levels);
}
AnyType
print_decision_tree::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
AnyType tuple;
tuple << static_cast<uint16_t>(dt.tree_depth - 1)
<< dt.feature_indices
<< dt.feature_thresholds
<< dt.is_categorical
<< dt.predictions
<< dt.surr_indices
<< dt.surr_thresholds
<< dt.surr_status;
return tuple;
}
AnyType
compute_variable_importance::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
const int n_cat_features = args[1].getAs<int>();
const int n_con_features = args[2].getAs<int>();
ColumnVector cat_var_importance = ColumnVector::Zero(n_cat_features);
ColumnVector con_var_importance = ColumnVector::Zero(n_con_features);
dt.computeVariableImportance(cat_var_importance, con_var_importance);
ColumnVector combined_var_imp(n_cat_features + n_con_features);
combined_var_imp << cat_var_importance, con_var_importance;
return combined_var_imp;
}
AnyType
display_text_tree::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
ArrayHandle<text*> con_feature_names = args[2].getAs<ArrayHandle<text*> >();
ArrayHandle<text*> cat_levels_text = args[3].getAs<ArrayHandle<text*> >();
ArrayHandle<int> cat_n_levels = args[4].getAs<ArrayHandle<int> >();
ArrayHandle<text*> dep_levels = args[5].getAs<ArrayHandle<text*> >();
return dt.print(0, cat_feature_names, con_feature_names, cat_levels_text,
cat_n_levels, dep_levels, 1u);
}
// ------------------------------------------------------------
// Prune the tree model using cost-complexity parameter
// ------------------------------------------------------------
// Remove the subtree rooted at 'me' and make 'me' a leaf
void mark_subtree_removal_recur(MutableTree &dt, int me) {
if (me < dt.predictions.rows() &&
dt.feature_indices(me) != dt.NODE_NON_EXISTING) {
int left = static_cast<int>(dt.trueChild(static_cast<Index>(me)));
int right = static_cast<int>(dt.falseChild(static_cast<Index>(me)));
mark_subtree_removal_recur(dt, left);
mark_subtree_removal_recur(dt, right);
dt.feature_indices(me) = dt.NODE_NON_EXISTING;
}
}
void mark_subtree_removal(MutableTree &dt, int me) {
mark_subtree_removal_recur(dt, me);
dt.feature_indices(me) = dt.FINISHED_LEAF;
}
// ------------------------------------------------------------
/*
* Data structure that carries the info for a node's sub-tree during pruning
*/
struct SubTreeInfo {
/* id of the root node of the subtree */
int root_id;
/* number of node splits */
int n_split;
/* current node's own risk */
double risk;
/*
* accumulated risk of sub-tree with the current
* node being the root
*/
double sum_risk;
/* sub-tree's average risk improvement per split */
double complexity;
SubTreeInfo * left_child;
SubTreeInfo * right_child;
SubTreeInfo(int i, int n, double r, double s, double c):
root_id(i), n_split(n), risk(r), sum_risk(s), complexity(c){
left_child = NULL;
right_child = NULL;
}
};
// FIXME: Remove after finalizing code
template <class IterableContainer>
string print_debug_list(IterableContainer debug_list){
std::stringstream debug;
typename IterableContainer::iterator it;
for (it = debug_list.begin(); it != debug_list.end(); it++){
debug << std::setprecision(8) << *it << ", ";
}
return debug.str();
}
/*
* Prune the tree by setting the pruned nodes'
* feature_indices value to NODE_NON_EXISTING.
*
* This closely follows rpart's implementation; see the
* source code of rpart/src/partition.c.
*/
SubTreeInfo prune_tree(MutableTree &dt,
int me,
double alpha,
double estimated_complexity,
std::vector<double> & node_complexities) {
if (me >= dt.feature_indices.size() || /* out of range */
dt.feature_indices(me) == dt.NODE_NON_EXISTING)
return SubTreeInfo(-1, 0, 0, 0, 0);
double risk = dt.computeRisk(me);
double adjusted_risk = risk > estimated_complexity ?
estimated_complexity : risk;
if (adjusted_risk <= alpha) {
/* If the current node's adjusted risk is no larger than alpha, then
* splitting can never decrease the risk by more than alpha. Remove the
* current node's subtree and make the current node a leaf.
*/
mark_subtree_removal(dt, me);
node_complexities[me] = alpha;
return SubTreeInfo(me, 0, risk, risk, alpha);
}
if (dt.feature_indices(me) >= 0) {
SubTreeInfo left = prune_tree(dt, 2*me+1, alpha,
adjusted_risk - alpha,
node_complexities);
double left_improve_per_split = (risk - left.sum_risk) / (left.n_split + 1);
double left_child_improve = risk - left.risk;
if (left_improve_per_split < left_child_improve)
left_improve_per_split = left_child_improve;
adjusted_risk = left_improve_per_split > estimated_complexity ?
estimated_complexity : left_improve_per_split;
SubTreeInfo right = prune_tree(dt, 2*me+2, alpha, adjusted_risk - alpha,
node_complexities);
/*
* Closely follows rpart's algorithm in rpart/src/partition.c.
*
* If the current split's average risk improvement per split is larger
* than a sub-tree's average improvement, the current split is the
* important one; we then manually inflate the current split's
* improvement so that this split is kept if possible.
*/
double left_risk = left.sum_risk;
double right_risk = right.sum_risk;
int left_n_split = left.n_split;
int right_n_split = right.n_split;
double tempcp = (risk - (left_risk + right_risk)) /
(left_n_split + right_n_split + 1);
if (right.complexity > left.complexity) {
if (tempcp > left.complexity) {
left_risk = left.risk;
left_n_split = 0;
tempcp = (risk - (left_risk + right_risk)) /
(left_n_split + right_n_split + 1);
if (tempcp > right.complexity) {
right_risk = right.risk;
right_n_split = 0;
}
}
} else if (tempcp > right.complexity) {
right_risk = right.risk;
right_n_split = 0;
tempcp = (risk - (left_risk + right_risk)) /
(left_n_split + right_n_split + 1);
if (tempcp > left.complexity) {
left_risk = left.risk;
left_n_split = 0;
}
}
double complexity = (risk - (left_risk + right_risk)) /
(left_n_split + right_n_split + 1);
if (complexity <= alpha) {
/* Prune this split by removing the subtree */
mark_subtree_removal(dt, me);
node_complexities[me] = alpha;
return SubTreeInfo(me, 0, risk, risk, alpha);
} else {
node_complexities[me] = complexity;
return SubTreeInfo(me, left_n_split + right_n_split + 1,
risk, left_risk + right_risk, complexity);
}
} // end of if (dt.feature_indices(me) >= 0)
else {
// node is a leaf node
node_complexities[me] = alpha;
return SubTreeInfo(me, 0, risk, risk, alpha);
}
}
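// Cost-complexity recap (mirroring rpart): for a node t whose subtree
// T_t contains n_split splits, the complexity computed above is
//     g(t) = (R(t) - R(T_t)) / n_split,
// the average risk reduction per split. Whenever g(t) <= alpha, where
// alpha = cp * root_risk (see prune_and_cplist below), the subtree is
// collapsed into a leaf.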
// ------------------------------------------------------------
void make_cp_list(MutableTree & dt,
std::vector<double> & node_complexities,
const double & alpha,
std::list<double> & cp_list,
const double root_risk){
cp_list.clear();
cp_list.push_back(node_complexities[0] / root_risk);
for (uint i = 1; i < node_complexities.size(); i++){
Index parent_id = dt.parentIndex(i);
if (dt.feature_indices(i) != dt.NODE_NON_EXISTING &&
dt.feature_indices(parent_id) != dt.NODE_NON_EXISTING){
double parent_cp = node_complexities[parent_id];
if (node_complexities[i] > parent_cp)
node_complexities[i] = parent_cp;
double current_cp = node_complexities[i];
if (current_cp < alpha)
current_cp = alpha; // don't explore any cp less than alpha
if (current_cp < parent_cp){
// node complexities are in risk units (alpha = cp * root_risk), but
// the user expects an unscaled cp value
current_cp /= root_risk;
std::list<double>::iterator it;
bool skip_cp = false;
for (it = cp_list.begin(); !skip_cp && it != cp_list.end(); it++){
if (fabs(current_cp - *it) < 1e-4)
skip_cp = true;
if (current_cp > *it)
break;
}
if (!skip_cp)
cp_list.insert(it, current_cp);
}
}
}
}
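// The resulting cp_list is kept in descending order, starting from the
// root's (scaled) complexity; candidate values within 1e-4 of an
// existing entry are skipped, and nothing below the user-supplied cp
// (i.e. alpha / root_risk) is explored.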
// -------------------------------------------------------------------------
AnyType
prune_and_cplist::run(AnyType &args){
MutableTree dt = args[0].getAs<MutableByteString>();
double cp = args[1].getAs<double>();
bool compute_cp_list = args[2].getAs<bool>();
// We use a scaled version of risk (similar to rpart's definition).
// The risk is relative to a tree with no splits (single node tree).
double root_risk = dt.computeRisk(0);
double alpha = cp * root_risk;
std::vector<double> node_complexities(dt.feature_indices.size(), alpha);
prune_tree(dt, 0, alpha, root_risk, node_complexities);
// Get the new tree_depth after pruning
// Note: externally, tree_depth starts from 0 but DecisionTree assumes
// tree_depth starts from 1
uint16_t pruned_depth = static_cast<uint16_t>(dt.recomputeTreeDepth() - 1);
AnyType output_tuple;
if (compute_cp_list){
std::list<double> cp_list;
make_cp_list(dt, node_complexities, alpha, cp_list, root_risk);
// we copy to a vector since we currently only have << operator
// defined for a vector
ColumnVector cp_vector(cp_list.size());
std::list<double>::iterator it = cp_list.begin();
for (Index i = 0; it != cp_list.end(); it++, i++){
cp_vector(i) = *it;
}
output_tuple << dt.storage() << pruned_depth << cp_vector;
} else {
output_tuple << dt.storage() << pruned_depth;
}
return output_tuple;
}
// ------------------------------------------------------------
// Helper function for PivotalR
// Convert the tree into the 'frame' item of rpart's result format
/*
* Fill one row of the frame matrix using data from the tree
*/
void fill_row(MutableNativeMatrix &frame, Tree &dt, int me, int i, int n_cats) {
frame(i,0) = static_cast<double>(dt.encodeIndex(dt.feature_indices(me),
dt.is_categorical(me), n_cats));
frame(i,5) = 1; // complexity is not needed in plotting
frame(i,6) = 0; // ncompete is not needed in plotting
// How many surrogate variables have been computed for this split
int n_surrogates = 0;
for (int ii = 0; ii < dt.max_n_surr; ii ++) {
if (dt.surr_indices(me * dt.max_n_surr + ii) >= 0) {
n_surrogates ++;
}
}
frame(i,7) = n_surrogates;
if (dt.is_regression) {
frame(i,1) = dt.predictions(me,3); // n
frame(i,2) = dt.predictions(me,0); // wt
frame(i,3) = dt.computeRisk(me); // weighted variance
frame(i,4) = dt.predictions(me, 1) / dt.predictions(me,0); // yval
} else {
double total_records = dt.nodeWeightedCount(0);
double n_records_innode = static_cast<double>(dt.nodeCount(me));
double n_records_weighted_innode = dt.nodeWeightedCount(me);
int n_dep_levels = static_cast<int>(dt.n_y_labels);
// FIXME use weight sum as the total number
frame(i,1) = n_records_innode;
frame(i,2) = n_records_weighted_innode;
frame(i,3) = dt.computeMisclassification(me);
Index max_index;
dt.predictions.row(me).head(dt.n_y_labels).maxCoeff(&max_index);
// start from 1 to be consistent with R convention
frame(i,4) = static_cast<double>(max_index + 1);
frame(i,8) = frame(i,4);
for (int j = 0; j < n_dep_levels; ++j) {
frame(i,9 + j) = dt.predictions(me, j);
frame(i,9 + j + n_dep_levels) =
dt.predictions(me, j) / n_records_innode;
}
frame(i,9 + 2 * n_dep_levels) = n_records_innode / total_records;
}
}
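// Summary of the rpart 'frame' columns filled above: 0 = var (encoded
// feature index), 1 = n, 2 = wt, 3 = dev (risk), 4 = yval,
// 5 = complexity (unused here, set to 1), 6 = ncompete (0),
// 7 = nsurrogate; for classification, column 8 repeats the predicted
// class and columns 9 onward hold the class counts, class proportions
// and the node's fraction of all records (rpart's 'yval2').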
// ------------------------------------------------------------
/*
* Recursively traverse the tree depth-first,
* filling the rows of 'frame' along the way
*/
void transverse_tree(Tree &dt, MutableNativeMatrix &frame,
int me, int &row, int n_cats) {
if (me < dt.feature_indices.size() && dt.feature_indices(me) != dt.NODE_NON_EXISTING) {
fill_row(frame, dt, me, row++, n_cats);
transverse_tree(dt, frame, static_cast<int>(dt.falseChild(me)), row, n_cats);
transverse_tree(dt, frame, static_cast<int>(dt.trueChild(me)), row, n_cats);
}
}
// ------------------------------------------------------------
AnyType convert_to_rpart_format::run(AnyType &args) {
Tree dt = args[0].getAs<ByteString>();
int n_cats = args[1].getAs<int>();
// number of nodes in the tree
int n_nodes = 0;
for (int i = 0; i < dt.feature_indices.size(); ++i) {
if (dt.feature_indices(i) != dt.NODE_NON_EXISTING) n_nodes++;
}
// number of columns in rpart frame
int n_col;
if (dt.is_regression) {
n_col = 8;
} else {
n_col = 10 + 2 * dt.n_y_labels;
}
MutableNativeMatrix frame(this->allocateArray<double>(
n_col, n_nodes), n_nodes, n_col);
int row = 0;
transverse_tree(dt, frame, 0, row, n_cats);
return frame;
}
// traverse the tree to collect the internal nodes' split thresholds
void
transverse_tree_thresh(const Tree &dt, MutableNativeMatrix &thresh,
int me, int &row, int n_cats) {
// only for non-leaf nodes
if (dt.feature_indices(me) >= 0) {
// primary
thresh(row,0) = dt.encodeIndex(dt.feature_indices(me),
dt.is_categorical(me),
n_cats);
thresh(row,1) = dt.feature_thresholds(me);
row++;
// surrogates
for (int ii = 0; ii < dt.max_n_surr; ii ++) {
int surr_ii = me * dt.max_n_surr + ii;
if (dt.surr_indices(surr_ii) >= 0) {
thresh(row,0) = dt.encodeIndex(dt.surr_indices(surr_ii),
dt.surr_status(surr_ii) == 1 || dt.surr_status(surr_ii) == -1,
n_cats);
thresh(row,1) = dt.surr_thresholds(surr_ii);
row++;
}
}
transverse_tree_thresh(dt, thresh, static_cast<int>(dt.falseChild(me)), row, n_cats);
transverse_tree_thresh(dt, thresh, static_cast<int>(dt.trueChild(me)), row, n_cats);
}
}
// ------------------------------------------------------------
AnyType
get_split_thresholds::run(AnyType &args) {
Tree dt = args[0].getAs<ByteString>();
int n_cats = args[1].getAs<int>();
// number of internal nodes
int in_nodes = 0;
// count how many surrogate variables in the whole tree
int tot_surr_n = 0;
for (int i = 0; i < dt.feature_indices.size(); ++i) {
if (dt.feature_indices(i) >= 0) {
in_nodes++;
for (int ii = 0; ii < dt.max_n_surr; ii ++) {
if (dt.surr_indices(i * dt.max_n_surr + ii) >= 0) {
tot_surr_n ++;
}
}
}
}
MutableNativeMatrix thresh(
this->allocateArray<double>(2, in_nodes + tot_surr_n),
in_nodes + tot_surr_n,
2);
int row = 0;
transverse_tree_thresh(dt, thresh, 0, row, n_cats);
return thresh;
}
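// Each internal node contributes one row for its primary split followed
// by one row per stored surrogate (column 0: encoded feature index,
// column 1: threshold), emitted in the same depth-first, false-then-true
// order used by transverse_tree_thresh.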
// ------------------------------------------------------------
/*
* PivotalR: randomForest
* Fill one row of the frame matrix using data from the tree
*/
void fill_one_row(MutableNativeMatrix &frame, Tree &dt, int me, int i,
int &node_index) {
int feature_index = dt.feature_indices(me);
if (feature_index == dt.FINISHED_LEAF) {
frame(i,0) = 0;
frame(i,1) = 0;
frame(i,2) = 0;
frame(i,4) = -1;
node_index--;
} else {
frame(i,0) = node_index * 2;
frame(i,1) = node_index * 2 + 1;
frame(i,4) = 1;
}
frame(i,2) = feature_index;
frame(i,3) = dt.feature_thresholds(me);
if (dt.is_regression) {
frame(i,5) = dt.predictions(me,1) / dt.predictions(me,0); // yval
} else {
Index max_index;
dt.predictions.row(me).head(dt.n_y_labels).maxCoeff(&max_index);
// start from 1 to be consistent with R convention
frame(i,5) = static_cast<double>(max_index + 1);
}
}
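// The six columns match R randomForest's getTree() layout: left
// daughter, right daughter, split var, split point, status (1 for an
// internal node, -1 for a terminal node) and prediction.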
/*
* PivotalR: randomForest
* Convert to R's randomForest format for the getTree() function
*
*/
AnyType
convert_to_random_forest_format::run(AnyType &args) {
Tree dt = args[0].getAs<ByteString>();
// number of nodes in the tree
int n_nodes = 0;
for (int i = 0; i < dt.feature_indices.size(); ++i) {
if (dt.feature_indices(i) != dt.NODE_NON_EXISTING) n_nodes++;
}
// number of columns in randomForest frame
MutableNativeMatrix frame(this->allocateArray<double>(
6, n_nodes), n_nodes, 6);
int row = 0;
int node_index = 1;
for (int i = 0; i<n_nodes; i++,node_index++,row++) {
fill_one_row(frame, dt, i, row, node_index);
}
return frame;
}
// ------------------------------------------------------------
} // namespace recursive_partitioning
} // namespace modules
} // namespace madlib