#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This script implements random forest prediction for recoded and binned
# categorical and numerical input features.
#
# INPUT:
# ------------------------------------------------------------------------------
# X        Feature matrix in recoded/binned representation
# y        Label matrix in recoded/binned representation,
#          optional for accuracy evaluation
# ctypes   Row-Vector of column types [1 scale/ordinal, 2 categorical]
#          of length ncol(X)+1, where the last entry is the label type
# M        Matrix M holding the learned trees (one tree per row),
#          see randomForest() for the detailed tree representation.
# verbose  Flag indicating verbose debug output
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# yhat     Label vector of predictions
# ------------------------------------------------------------------------------
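#
# Example (illustrative usage sketch, not part of this script; the read() calls,
# file names, and the all-categorical ctypes below are assumptions for
# demonstration, and inputs must already be recoded/binned, e.g., via
# transformencode):
#
#   X = read("X_encoded.csv");                    # recoded/binned features
#   y = read("y_encoded.csv");                    # recoded labels (optional)
#   ctypes = matrix(2, rows=1, cols=ncol(X)+1);   # column types incl. label column
#   M = randomForest(X=X, y=y, ctypes=ctypes);    # train the forest
#   yhat = randomForestPredict(X=X, y=y, ctypes=ctypes, M=M, verbose=TRUE);
#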
m_randomForestPredict = function(Matrix[Double] X, Matrix[Double] y = matrix(0,0,0),
Matrix[Double] ctypes, Matrix[Double] M, Boolean verbose = FALSE)
return(Matrix[Double] yhat)
{
t1 = time();
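# the last entry of ctypes is the label type: 2 (categorical) -> classification,
# otherwise regression; labels y are optional and only used for the summary stats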
classify = as.scalar(ctypes[1,ncol(X)+1]) == 2;
yExists = (nrow(X)==nrow(y));
if(verbose) {
print("randomForestPredict: called for batch of "+nrow(X)+" rows, model of "
+nrow(M)+" trees, and with labels-provided "+yExists+".");
}
# scoring of num_tree decision trees
Ytmp = matrix(0, rows=nrow(M), cols=nrow(X));
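# Ytmp collects the per-tree predictions (one row per tree, one column per row of X)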
parfor(i in 1:nrow(M)) {
if( verbose )
print("randomForest: start scoring tree "+i+"/"+nrow(M)+".");
# step 1: select the feature subset sampled during training (indicator in M[i,1:ncol(X)])
I2 = M[i, 1:ncol(X)];
Xi = removeEmpty(target=X, margin="cols", select=I2);
# step 2: score decision tree
t2 = time();
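# the model of tree i occupies the columns of M[i,] after the ncol(X) feature indicators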
ret = decisionTreePredict(X=Xi, M=M[i,ncol(X)+1:ncol(M)], ctypes=ctypes, strategy="TT");
Ytmp[i,1:nrow(ret)] = t(ret);
if( verbose )
print("-- ["+i+"] scored decision tree in "+(time()-t2)/1e9+" seconds.");
}
# label aggregation (majority voting / average)
yhat = matrix(0, nrow(X), 1);
if( classify ) {
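# classification: majority vote per row, counting the per-tree votes via a
# contingency table and picking the class with the most votes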
parfor(i in 1:nrow(X))
yhat[i,1] = rowIndexMax(t(table(Ytmp[,i],1)));
}
else {
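# regression: average the per-tree predictions for each row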
yhat = t(colSums(Ytmp)/nrow(M));
}
# summary statistics
if( yExists & verbose ) {
if( classify )
print("Accuracy (%): " + (sum(yhat == y) / nrow(y) * 100));
else
lmPredictStats(yhat, y, FALSE);
}
if(verbose) {
print("randomForest: scored batch of "+nrow(X)+" rows with "+nrow(M)+" trees in "+(time()-t1)/1e9+" seconds.");
}
}