blob: 5d62ac61fb7723a03ce82c1e5d997f97a69b0c73 [file] [log] [blame]
/*
* Copyright 2014 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Author: stevensr@google.com (Ryan Stevens)
#include "net/instaweb/rewriter/public/decision_tree.h"
#include "base/logging.h"
namespace net_instaweb {
DecisionTree::DecisionTree(const Node* nodes, int num_nodes)
: nodes_(nodes),
num_nodes_(num_nodes),
num_features_(0) {
#ifndef NDEBUG
DCHECK_LT(0, num_nodes_);
SanityCheck();
#endif
int max_feature_index = -1;
for (int i = 0; i < num_nodes_; ++i) {
if (nodes[i].feature_index > max_feature_index) {
max_feature_index = nodes[i].feature_index;
}
}
num_features_ = max_feature_index + 1;
}
DecisionTree::~DecisionTree() {}
const DecisionTree::Node* DecisionTree::Root() const {
return &nodes_[0];
}
double DecisionTree::Predict(std::vector<double> const& sample) const {
// num_features_ can be smaller than sample.size() if we are ignoring some
// signals (because they are new or considered irrelevant after training).
DCHECK_LE(num_features_, static_cast<int>(sample.size()));
const Node* cur = Root();
while (cur != NULL && !cur->IsLeafNode()) {
if (sample[cur->feature_index] <= cur->feature_threshold) {
cur = cur->left;
} else {
cur = cur->right;
}
}
CHECK(cur != NULL);
return cur->confidence;
}
void DecisionTree::SanityCheck() const {
// We will do a tree traversal to to ensure the following invariants about
// a constructed decision tree:
// 1) All nodes are reachable.
// 2) All nodes have 2 (inner nodes) or 0 (leaf nodes) children.
// 3) All inner nodes have a feature_index > 0.
// 4) All leaf nodes have a 0.0 <= confidence <= 1.0.
const Node* root = Root();
int num_observed_nodes = 0;
SanityCheckTraversal(root, &num_observed_nodes);
DCHECK_GE(num_observed_nodes, num_nodes_) << "Unreachable nodes";
DCHECK_LE(num_observed_nodes, num_nodes_) << "Extraneous nodes";
}
void DecisionTree::SanityCheckTraversal(const Node* cur, int* num_nodes) const {
(*num_nodes)++;
DCHECK((cur->left != NULL && cur->right != NULL) ||
(cur->left == NULL && cur->right == NULL))
<< "Inner node has one child";
if (cur->IsLeafNode()) {
DCHECK(cur->confidence >= 0.0 && cur->confidence <= 1.0)
<< "Invalid confidence " << cur->confidence;
} else {
DCHECK_LE(0, cur->feature_index) << "Invalid feature index";
SanityCheckTraversal(cur->left, num_nodes);
SanityCheckTraversal(cur->right, num_nodes);
}
}
} // namespace net_instaweb