| /* |
| * Copyright 2014 Google Inc. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // Author: stevensr@google.com (Ryan Stevens) |
| |
| #ifndef NET_INSTAWEB_REWRITER_PUBLIC_DECISION_TREE_H_ |
| #define NET_INSTAWEB_REWRITER_PUBLIC_DECISION_TREE_H_ |
| |
| #include <cstddef> |
| #include <vector> |
| |
| namespace net_instaweb { |
| |
| // A simple decision tree classifier that can classify samples given to it into |
// different classes. It cannot train itself; rather, the decision tree is
| // created from an exported DecisionTreeClassifier generated by python sklearn |
| // (see http://scikit-learn.org/stable/modules/tree.html). |
| // |
| // Nodes are defined by a tuple: |
| // (feature_index, feature_threshold, confidence, left_child, right_child) |
| // To make a prediction with the tree, we start from the root node and make a |
| // move to either the left child if (X[feature_index] <= feature_threshold), |
| // else we move to the right child. Thus, all inner nodes of the tree must have |
| // two children. We repeat this until we reach a leaf node, which has a |
| // prediction confidence. The prediction confidence generally indicates how |
| // confident we are that the sample belongs in the positive class. Typically you |
// would accept the sample into the positive class if the confidence > 0.5.
class DecisionTree {
 public:
  // A single tree node. Inner nodes route a sample to the left child when
  // (X[feature_index] <= feature_threshold) and to the right child otherwise;
  // leaf nodes carry the final prediction confidence.
  struct Node {
    int feature_index;
    double feature_threshold;
    double confidence;
    const Node* left;
    const Node* right;
    // A well-formed inner node always has both children. Out of an excess of
    // caution, any node missing either child is treated as a leaf so that
    // Predict(...) never follows a missing branch and crashes.
    bool IsLeafNode() const { return !(left != NULL && right != NULL); }
  };

  // The nodes array is expected to be statically allocated in the code and
  // must outlive this object; this class takes no ownership of it.
  DecisionTree(const Node* nodes, int num_nodes);
  ~DecisionTree();

  // Number of features a sample passed to Predict() is expected to provide
  // (derived from the node array at construction; see the .cc).
  int num_features() const { return num_features_; }

  // Returns the confidence that the given sample belongs to the positive
  // class. The sample must contain one measurement per relevant feature,
  // indexed in the same order that was used to train the tree.
  // TODO(stevensr): The ordering of measurements in the sample will become
  // more concrete when the feature measurement logic is committed.
  double Predict(const std::vector<double>& sample) const;

 private:
  // Consistency checks on the node array (performed via a full traversal);
  // see the .cc for details.
  void SanityCheck() const;
  void SanityCheckTraversal(const Node* cur, int* num_nodes) const;

  const Node* Root() const;

  const Node* nodes_;  // Not owned.
  int num_nodes_;
  int num_features_;
};
| |
| } // namespace net_instaweb |
| |
| #endif // NET_INSTAWEB_REWRITER_PUBLIC_DECISION_TREE_H_ |