#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This script implements random forest prediction for recoded and binned
# categorical and numerical input features.
#
#
# .. code-block:: python
#
# >>> import numpy as np
# >>> from systemds.context import SystemDSContext
# >>> from systemds.operator.algorithm import randomForest, randomForestPredict
# >>>
# >>> # tiny toy dataset
# >>> X = np.array([[1],
# ... [2],
# ... [10],
# ... [11]], dtype=np.int64)
# >>> y = np.array([[1],
# ... [1],
# ... [2],
# ... [2]], dtype=np.int64)
# >>>
# >>> with SystemDSContext() as sds:
# ... X_sds = sds.from_numpy(X)
# ... y_sds = sds.from_numpy(y)
# ...
# ... ctypes = sds.from_numpy(np.array([[1, 2]], dtype=np.int64))
# ...
# ... # train a 4-tree forest (no sampling)
# ... M = randomForest(
# ... X_sds, y_sds, ctypes,
# ... num_trees = 4,
# ... sample_frac = 1.0,
# ... feature_frac = 1.0,
# ... max_depth = 3,
# ... min_leaf = 1,
# ... min_split = 2,
# ... seed = 42
# ... )
# ...
# ... preds = randomForestPredict(X_sds, ctypes, M).compute()
# ... print(preds)
# [[1.]
# [1.]
# [2.]
# [2.]]
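#
# If ground-truth labels are available, they can be passed along to obtain an
# accuracy printout (a sketch, assuming the generated Python bindings expose the
# optional y and verbose arguments as keyword parameters; the exact log output
# depends on the SystemDS runtime):
#
# .. code-block:: python
#
# >>> # inside the same SystemDSContext as above
# >>> preds = randomForestPredict(X_sds, ctypes, M, y=y_sds, verbose=True).compute()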
#
#
# INPUT:
# ------------------------------------------------------------------------------
# X Feature matrix in recoded/binned representation
# y Label matrix in recoded/binned representation,
# optional for accuracy evaluation
# ctypes Row-Vector of column types [1 scale/ordinal, 2 categorical]
# M Matrix M holding the learned trees (one tree per row),
# see randomForest() for the detailed tree representation.
# verbose Flag indicating verbose debug output
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# yhat Label vector of predictions
# ------------------------------------------------------------------------------
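#
# Example invocation from DML (a sketch; assumes X, ctypes, and a model matrix M
# produced by randomForest() are already in scope):
#
#   yhat = randomForestPredict(X=X, ctypes=ctypes, M=M, verbose=TRUE);
#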
m_randomForestPredict = function(Matrix[Double] X, Matrix[Double] y = matrix(0,0,0),
Matrix[Double] ctypes, Matrix[Double] M, Boolean verbose = FALSE)
return(Matrix[Double] yhat)
{
t1 = time();
classify = as.scalar(ctypes[1,ncol(X)+1]) == 2;
yExists = (nrow(X)==nrow(y));
if(verbose) {
print("randomForestPredict: called for batch of "+nrow(X)+" rows, model of "
+nrow(M)+" trees, and with labels-provided "+yExists+".");
}
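# Each row of M encodes one tree: the first ncol(X) entries form an indicator
# vector (nonzero = selected) of the features the tree was trained on, and the
# remaining entries hold the serialized tree consumed by decisionTreePredict.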
# scoring of num_tree decision trees
Ytmp = matrix(0, rows=nrow(M), cols=nrow(X));
parfor(i in 1:nrow(M)) {
if( verbose )
print("randomForest: start scoring tree "+i+"/"+nrow(M)+".");
# step 1: sample features (consistent with training)
I2 = M[i, 1:ncol(X)];
Xi = removeEmpty(target=X, margin="cols", select=I2);
# step 2: score decision tree
t2 = time();
ret = decisionTreePredict(X=Xi, M=M[i,ncol(X)+1:ncol(M)], ctypes=ctypes, strategy="TT");
Ytmp[i,1:nrow(ret)] = t(ret);
if( verbose )
print("-- ["+i+"] scored decision tree in "+(time()-t2)/1e9+" seconds.");
}
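# Ytmp now holds one prediction per tree and input row: Ytmp[i,j] is the
# prediction of tree i for row j of X.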
# label aggregation (majority voting / average)
yhat = matrix(0, nrow(X), 1);
if( classify ) {
parfor(i in 1:nrow(X))
yhat[i,1] = rowIndexMax(t(table(Ytmp[,i],1)));
}
else {
yhat = t(colSums(Ytmp)/nrow(M));
}
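# Example: if four trees vote [1, 2, 1, 1] for a row, table() counts the votes
# per recoded class label and rowIndexMax returns class 1; in the regression
# case the per-row tree predictions are simply averaged.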
# summary statistics
if( yExists & verbose ) {
if( classify )
print("Accuracy (%): " + (sum(yhat == y) / nrow(y) * 100));
else
lmPredictStats(yhat, y, FALSE);
}
if(verbose) {
print("randomForest: scored batch of "+nrow(X)+" rows with "+nrow(M)+" trees in "+(time()-t1)/1e9+" seconds.");
}
}