blob: c375c13cce3c6e7341d0e61c5eeae48f587647f2 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
#
# THIS SCRIPT COMPUTES LABEL PREDICTIONS MEANT FOR USE WITH A DECISION TREE MODEL ON A HELD OUT TEST SET.
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# X String --- Location to read the test feature matrix X; note that X needs to be both recoded and dummy coded
# Y String " " Location to read the true label matrix Y if requested; note that Y needs to be both recoded and dummy coded
# R String " " Location to read matrix R which for each feature in X contains the following information
# - R[,1]: column ids
# - R[,2]: start indices
# - R[,3]: end indices
# If R is not provided, all variables are assumed to be scale by default
# M String --- Location to read matrix M containing the learned tree in the following format
# - M[1,j]: id of node j (in a complete binary tree)
# - M[2,j]: Offset (no. of columns) to left child of j if j is an internal node, otherwise 0
# - M[3,j]: Feature index of the feature that node j looks at if j is an internal node, otherwise 0
# - M[4,j]: Type of the feature that node j looks at if j is an internal node: 1 for scale and 2 for categorical features,
# otherwise the label that leaf node j is supposed to predict
# - M[5,j]: If j is an internal node: 1 if the feature chosen for j is scale, otherwise the size of the subset of values
# stored in rows 6,7,... if j is categorical
# If j is a leaf node: number of misclassified samples reaching node j
# - M[6:,j]: If j is an internal node: Threshold the example's feature value is compared to is stored at M[6,j]
# if the feature chosen for j is scale, otherwise if the feature chosen for j is categorical rows 6,7,...
# depict the value subset chosen for j
# If j is a leaf node: 1 if j is impure and the number of samples at j > threshold, otherwise 0
# P String --- Location to store the label predictions for X
# A String " " Location to write the test accuracy (%) for the prediction if requested
# CM String " " Location to write the confusion matrix if requested
# fmt String "text" Format of the output files, such as "text" or "csv"
# ---------------------------------------------------------------------------------------------
# OUTPUT:
# 1- Matrix P containing the predicted labels for X
# 2- Test accuracy if requested
# 3- Confusion matrix C if requested
# -------------------------------------------------------------------------------------------
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemML.jar -f decision-tree-predict.dml -nvargs X=INPUT_DIR/X Y=INPUT_DIR/Y R=INPUT_DIR/R M=INPUT_DIR/model P=OUTPUT_DIR/predictions
# A=OUTPUT_DIR/accuracy CM=OUTPUT_DIR/confusion fmt=csv
# ---- Command-line arguments ----
# Mandatory: X (test features), M (learned tree), P (where predictions are written).
fileX = $X;
fileM = $M;
fileP = $P;
# Optional: a single blank (" ") is the sentinel for "argument not supplied".
fileY  = ifdef ($Y, " ");
fileR  = ifdef ($R, " ");
fileA  = ifdef ($A, " ");
fileCM = ifdef ($CM, " ");
fmtO   = ifdef ($fmt, "text");

# ---- Inputs and result buffer ----
M = read (fileM);
X_test = read (fileX);
num_records = nrow (X_test);
# One predicted label per test record, filled in by the prediction loop.
Y_predicted = matrix (0, rows = num_records, cols = 1);

# ---- Feature-to-column lookup tables ----
# R_scale: one row per scale feature, holding the column of X it lives in.
# R_cat:   one row per categorical feature, holding [start column, end column] of its dummy-coded span.
# Both start as 1x1 zero placeholders so that they are defined whichever branch runs below.
R_scale = matrix (0, rows = 1, cols = 1);
R_cat = matrix (0, rows = 1, cols = 1);
if (fileR == " ") {
	# No R supplied: every column of X is treated as its own scale feature.
	R_scale = seq (1, ncol (X_test));
} else {
	feature_map = read (fileR);
	# A feature whose start and end indices differ spans several columns, i.e. it is dummy coded.
	is_dummy_coded = (feature_map[,2] != feature_map[,3]);
	# Zero out the rows belonging to the other feature type, then drop the all-zero rows.
	R_scale = removeEmpty (target = feature_map[,2] * (1 - is_dummy_coded), margin = "rows");
	R_cat = removeEmpty (target = feature_map[,2:3] * is_dummy_coded, margin = "rows");
}
# Route every test record from the root of the tree (column 1 of M) down to a leaf and
# store the leaf's label in Y_predicted. Each iteration reads only row i of X_test and
# writes only row i of Y_predicted.
# NOTE(review): check = 0 is understood to switch off parfor's loop-dependency check -- confirm
# against the DML language reference.
parfor (i in 1:num_records, check = 0) {
	cur_sample = X_test[i,];
	cur_node_pos = 1;   # column of M describing the node currently visited
	label_found = FALSE;
	while (!label_found) {
		cur_feature = as.scalar (M[3,cur_node_pos]);
		# Row 4 is overloaded: feature type (1 = scale, 2 = categorical) for an internal node,
		# predicted label for a leaf.
		type_label = as.scalar (M[4,cur_node_pos]);
		if (cur_feature == 0) { # leaf node
			label_found = TRUE;
			Y_predicted[i,] = type_label;
		} else { # internal node: decide which child to visit next
			go_left = FALSE;
			if (type_label == 1) { # scale feature: compare the value against the threshold in M[6,]
				cur_start_ind = as.scalar (R_scale[cur_feature,]);
				cur_value = as.scalar (cur_sample[,cur_start_ind]);
				cur_split = as.scalar (M[6,cur_node_pos]);
				go_left = (cur_value < cur_split);
			} else if (type_label == 2) { # categorical feature: test membership in the node's value subset
				cur_start_ind = as.scalar (R_cat[cur_feature,1]);
				cur_end_ind = as.scalar (R_cat[cur_feature,2]);
				# Position of the largest entry within the feature's dummy-coded columns = category id.
				cur_value = as.scalar (rowIndexMax (cur_sample[,cur_start_ind:cur_end_ind]));
				# M[5,] gives the size of the value subset stored in rows 6, 7, ... of this column.
				cur_offset = as.scalar (M[5,cur_node_pos]);
				go_left = (sum (M[6:(6 + cur_offset - 1),cur_node_pos] == cur_value) > 0);
			} else {
				# Without this guard the node position would never change and the while loop
				# would spin forever on a malformed model.
				stop ("decision-tree-predict: unsupported feature type " + type_label + " at model column " + cur_node_pos + " (expected 1 for scale or 2 for categorical)");
			}
			# M[2,] holds the column offset to the left child; the right child is stored right after it.
			left_offset = as.scalar (M[2,cur_node_pos]);
			if (go_left) {
				cur_node_pos = cur_node_pos + left_offset;
			} else {
				cur_node_pos = cur_node_pos + left_offset + 1;
			}
		}
	}
}
# Persist the predicted labels.
write (Y_predicted, fileP, format = fmtO);

# Accuracy and confusion matrix are produced only when the true labels Y were supplied.
if (fileY != " ") {
	Y_dummy = read (fileY);
	num_classes = ncol (Y_dummy);
	# Collapse the dummy-coded label matrix to one class id per row:
	# column k is weighted by k, then each row is summed.
	class_ids = t (seq (1, num_classes));
	Y_test = rowSums (Y_dummy * class_ids);

	# Percentage of records whose predicted label equals the true label.
	num_correct = sum (Y_test == Y_predicted);
	accuracy = num_correct / num_records * 100;
	acc_str = "Accuracy (%): " + accuracy;
	if (fileA == " ") {
		# No accuracy file requested: report on standard output instead.
		print (acc_str);
	} else {
		write (acc_str, fileA, format = fmtO);
	}

	if (fileCM != " ") {
		# num_classes x num_classes contingency table of predicted vs. true labels.
		confusion_mat = table (Y_predicted, Y_test, num_classes, num_classes);
		write (confusion_mat, fileCM, format = fmtO);
	}
}