blob: 1e0988d2a2b45ccf6a49d1f4e2a3a6d5bc7825a8 [file] [log] [blame]
/*
* Copyright 2014 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Author: stevensr@google.com (Ryan Stevens)
#ifndef NET_INSTAWEB_REWRITER_PUBLIC_DECISION_TREE_H_
#define NET_INSTAWEB_REWRITER_PUBLIC_DECISION_TREE_H_
#include <cstddef>
#include <vector>
namespace net_instaweb {
// A simple decision tree classifier that can classify samples given to it into
// different classes. Is not able to train itself, rather the decision tree is
// created from an exported DecisionTreeClassifier generated by python sklearn
// (see http://scikit-learn.org/stable/modules/tree.html).
//
// Nodes are defined by a tuple:
// (feature_index, feature_threshold, confidence, left_child, right_child)
// To make a prediction with the tree, we start from the root node and make a
// move to either the left child if (X[feature_index] <= feature_threshold),
// else we move to the right child. Thus, all inner nodes of the tree must have
// two children. We repeat this until we reach a leaf node, which has a
// prediction confidence. The prediction confidence generally indicates how
// confident we are that the sample belongs in the positive class. Typically you
would accept the sample into the positive class if the confidence > 0.5.
class DecisionTree {
 public:
  // A single node of the tree. Inner nodes carry a (feature_index,
  // feature_threshold) split; leaves carry the prediction confidence.
  struct Node {
    int feature_index;         // Index into the sample vector (inner nodes).
    double feature_threshold;  // Go left if X[feature_index] <= threshold.
    double confidence;         // Prediction confidence (meaningful at leaves).
    const Node* left;          // Child taken when X[feature_index] <= threshold.
    const Node* right;         // Child taken when X[feature_index] > threshold.

    // Nodes must be strictly binary. Out of an excess of caution we
    // consider any node with a single child to be a leaf, as that avoids
    // crashes in Predict(...).
    bool IsLeafNode() const { return (left == nullptr || right == nullptr); }
  };

  // The nodes variable is expected to be statically allocated in the code, and
  // this class will not take ownership of it.
  DecisionTree(const Node* nodes, int num_nodes);
  ~DecisionTree();

  // Number of features this tree inspects. Samples passed to Predict()
  // should be indexed consistently with this count (computed at
  // construction time in the .cc — confirm there).
  int num_features() const { return num_features_; }

  // Predict whether this sample belongs in the positive or negative class,
  // returning the confidence that it is in the positive class. The sample
  // should contain measurements for each of the relevant features, indexed
  // in the same order that was used to train the tree.
  // TODO(stevensr): The ordering of measurements in the sample will become
  // more concrete when the feature measurement logic is committed.
  double Predict(std::vector<double> const& sample) const;

 private:
  // Consistency checks over the node array (implemented in the .cc).
  void SanityCheck() const;
  void SanityCheckTraversal(const Node* cur, int* num_nodes) const;
  const Node* Root() const;

  const Node* nodes_;  // Not owned; expected to outlive this object.
  int num_nodes_;
  int num_features_;
};
} // namespace net_instaweb
#endif // NET_INSTAWEB_REWRITER_PUBLIC_DECISION_TREE_H_