| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # |
| # THIS SCRIPT COMPUTES LABEL PREDICTIONS FOR A HELD-OUT TEST SET USING A LEARNED DECISION TREE MODEL. |
| # |
| # INPUT PARAMETERS: |
| # --------------------------------------------------------------------------------------------- |
| # NAME TYPE DEFAULT MEANING |
| # --------------------------------------------------------------------------------------------- |
| # X String --- Location to read the test feature matrix X; note that X needs to be both recoded and dummy coded |
| # Y String " " Location to read the true label matrix Y if requested; note that Y needs to be both recoded and dummy coded |
| # R String " " Location to read matrix R which for each feature in X contains the following information |
| # - R[,1]: column ids |
| # - R[,2]: start indices |
| # - R[,3]: end indices |
| # If R is not provided, all variables are assumed to be scale by default |
| # M String --- Location to read matrix M containing the learned tree in the following format |
| # - M[1,j]: id of node j (in a complete binary tree) |
| # - M[2,j]: Offset (no. of columns) to left child of j if j is an internal node, otherwise 0 |
| # - M[3,j]: Feature index of the feature that node j looks at if j is an internal node, otherwise 0 |
| # - M[4,j]: Type of the feature that node j looks at if j is an internal node: 1 for scale and 2 for categorical features, |
| # otherwise the label that leaf node j is supposed to predict |
| # - M[5,j]: If j is an internal node: 1 if the feature chosen for j is scale, otherwise the size of the subset of values |
| # stored in rows 6,7,... if j is categorical |
| # If j is a leaf node: number of misclassified samples reaching node j |
| # - M[6:,j]: If j is an internal node: Threshold the example's feature value is compared to is stored at M[6,j] |
| # if the feature chosen for j is scale, otherwise if the feature chosen for j is categorical rows 6,7,... |
| # depict the value subset chosen for j |
| # If j is a leaf node 1 if j is impure and the number of samples at j > threshold, otherwise 0 |
| # P String --- Location to store the label predictions for X |
| # A String " " Location to write the test accuracy (%) for the prediction if requested |
| # CM String " " Location to write the confusion matrix if requested |
| # fmt String "text" The format of the output files, such as "text" or "csv" |
| # --------------------------------------------------------------------------------------------- |
| # OUTPUT: |
| # 1- Matrix P containing the predicted labels for X |
| # 2- Test accuracy if requested |
| # 3- Confusion matrix CM if requested |
| # ------------------------------------------------------------------------------------------- |
| # HOW TO INVOKE THIS SCRIPT - EXAMPLE: |
| # hadoop jar SystemDS.jar -f decision-tree-predict.dml -nvargs X=INPUT_DIR/X Y=INPUT_DIR/Y R=INPUT_DIR/R M=INPUT_DIR/model P=OUTPUT_DIR/predictions |
| # A=OUTPUT_DIR/accuracy CM=OUTPUT_DIR/confusion fmt=csv |
| |
| # Mandatory arguments: test features, learned tree model, and where to store the predictions. |
| fileX = $X; |
| fileM = $M; |
| fileP = $P; |
| |
| # Optional arguments; a single blank (" ") marks an argument as not supplied. |
| fmtO = ifdef ($fmt, "text"); |
| fileA = ifdef ($A, " "); |
| fileCM = ifdef ($CM, " "); |
| fileR = ifdef ($R, " "); |
| fileY = ifdef ($Y, " "); |
| |
| # Load the learned tree and the test feature matrix. |
| M = read (fileM); |
| X_test = read (fileX); |
| |
| # One prediction slot per test record; filled in by the prediction loop. |
| num_records = nrow (X_test); |
| Y_predicted = matrix (0, rows = num_records, cols = 1); |
| |
| # Placeholders so both lookup tables exist regardless of the branch taken below. |
| R_scale = matrix (0, rows = 1, cols = 1); |
| R_cat = matrix (0, rows = 1, cols = 1); |
| |
| if (fileR == " ") { |
|   # No feature map supplied: every column of X is treated as its own scale feature. |
|   R_scale = seq (1, ncol (X_test)); |
| } else { |
|   R = read (fileR); |
|   # A feature whose start and end indices differ spans several columns, i.e. it is dummy coded. |
|   is_categorical = (R[,2] != R[,3]); |
|   is_scale = 1 - is_categorical; |
|   # Zero out the rows belonging to the other feature type, then drop them, so that |
|   # R_scale holds start indices of scale features and R_cat holds [start, end] of categorical ones. |
|   scale_starts = R[,2] * is_scale; |
|   cat_ranges = R[,2:3] * is_categorical; |
|   R_scale = removeEmpty (target = scale_starts, margin = "rows"); |
|   R_cat = removeEmpty (target = cat_ranges, margin = "rows"); |
| } |
| |
| # Route every test record from the root of the tree (column 1 of M) down to a leaf and |
| # store that leaf's label. Each iteration writes only its own row i of Y_predicted. |
| parfor (i in 1:num_records, check = 0) { |
|   cur_sample = X_test[i,]; |
|   cur_node_pos = 1; |
|   label_found = FALSE; |
|   while (!label_found) { |
|     cur_feature = as.scalar (M[3,cur_node_pos]); |
|     type_label = as.scalar (M[4,cur_node_pos]); |
|     if (cur_feature == 0) { # leaf node: row 4 holds the label to predict |
|       label_found = TRUE; |
|       Y_predicted[i,] = type_label; |
|     } else { |
|       # internal node: row 4 holds the feature type (1 for scale, 2 for categorical) |
|       go_left = FALSE; |
|       if (type_label == 1) { # scale feature: compare the value against the threshold in row 6 |
|         cur_start_ind = as.scalar (R_scale[cur_feature,]); |
|         cur_value = as.scalar (cur_sample[,cur_start_ind]); |
|         cur_split = as.scalar (M[6,cur_node_pos]); |
|         go_left = (cur_value < cur_split); |
|       } else if (type_label == 2) { # categorical feature: test membership in the subset stored from row 6 on |
|         cur_start_ind = as.scalar (R_cat[cur_feature,1]); |
|         cur_end_ind = as.scalar (R_cat[cur_feature,2]); |
|         # position of the largest entry within the feature's dummy-coded column range = category id |
|         cur_value = as.scalar (rowIndexMax (cur_sample[,cur_start_ind:cur_end_ind])); |
|         cur_offset = as.scalar (M[5,cur_node_pos]); |
|         go_left = (sum (M[6:(6 + cur_offset - 1),cur_node_pos] == cur_value) > 0); |
|       } else { |
|         # Without this guard an unknown type would leave cur_node_pos unchanged and the loop would never end. |
|         stop ("Decision tree model is malformed: node at column " + cur_node_pos + " has unknown feature type " + type_label); |
|       } |
|       # Row 2 holds the column offset to the left child; the right child is the column right after it. |
|       left_child_pos = cur_node_pos + as.scalar (M[2,cur_node_pos]); |
|       if (go_left) { |
|         cur_node_pos = left_child_pos; |
|       } else { |
|         cur_node_pos = left_child_pos + 1; |
|       } |
|     } |
|   } |
| } |
| |
| # Persist the predicted labels. |
| write (Y_predicted, fileP, format = fmtO); |
| |
| # Accuracy and confusion matrix are only computed when true labels were supplied. |
| if (fileY != " ") { |
|   Y_dummy = read (fileY); |
|   num_classes = ncol (Y_dummy); |
|   # Collapse the dummy-coded labels into one column of class ids: |
|   # weight column k by k and sum across each row. |
|   class_ids = t (seq (1, num_classes)); |
|   Y_test = rowSums (Y_dummy * class_ids); |
|   num_correct = sum (Y_test == Y_predicted); |
|   accuracy = num_correct / num_records * 100; |
|   acc_str = "Accuracy (%): " + accuracy; |
|   if (fileA == " ") { |
|     # no accuracy file requested: report on standard output instead |
|     print (acc_str); |
|   } else { |
|     write (acc_str, fileA, format = fmtO); |
|   } |
|   if (fileCM != " ") { |
|     # predicted labels index the rows, true labels index the columns |
|     confusion_mat = table (Y_predicted, Y_test, num_classes, num_classes); |
|     write (confusion_mat, fileCM, format = fmtO); |
|   } |
| } |