| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # This script can be used to compute label predictions |
| # Meant for use with an SVM model (learnt using m-svm.dml) on a held out test set |
| # |
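| # Unless scoring_only=TRUE, the script reads the ground truth labels (Y), |
| # computes the accuracy (%) of the predictions and, if a path is given |
| # via confusion=..., writes a confusion matrix |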
| # |
| # Example Usage: |
| # hadoop jar SystemDS.jar -f m-svm-predict.dml -nvargs X=data Y=labels scoring_only=FALSE model=model scores=scores accuracy=accuracy confusion=confusion fmt="text" |
| # |
| |
| # Optional named arguments. A single blank (" ") is the sentinel for "not supplied". |
| |
| # Run-mode switch and the format used when writing the scores matrix. |
| cmdLine_scoring_only = ifdef($scoring_only, FALSE); |
| cmdLine_fmt = ifdef($fmt, "text"); |
| |
| # Ground truth labels (input path). |
| cmdLine_Y = ifdef($Y, " "); |
| |
| # Output paths; each output is written only when its path is supplied. |
| cmdLine_scores = ifdef($scores, " "); |
| cmdLine_accuracy = ifdef($accuracy, " "); |
| cmdLine_confusion = ifdef($confusion, " "); |
| |
| # Load the data to score and the model from the paths given on the command line. |
| X = read($X); |
| W = read($model); |
| |
| # The last two rows of the model matrix hold metadata instead of weights: |
| # the final row stores the number of input columns the model expects, |
| # the row before it stores the intercept flag. |
| model_rows = nrow(W); |
| expected_cols = as.scalar(W[model_rows,1]); |
| if (expected_cols != ncol(X)) { |
|   stop("Stopping due to invalid input: Model dimensions do not seem to match input data dimensions"); |
| } |
| |
| intercept = as.scalar(W[model_rows-1,1]); |
| |
| # Drop the two metadata rows; what remains are the per-class weight columns. |
| W = W[1:(model_rows-2),]; |
| |
| N = nrow(X); |
| m = ncol(X); |
| num_classes = ncol(W); |
| |
| # Additive per-class offset: row m+1 of the model when the intercept flag is 1, |
| # otherwise a row of zeros. |
| if (intercept == 1) { |
|   bias = W[m+1,]; |
| } else { |
|   bias = matrix(0, rows=1, cols=num_classes); |
| } |
| |
| # Multiplying a column of ones by the offset row repeats it once per row of X. |
| ones_col = matrix(1, rows=N, cols=1); |
| scores = X %*% W[1:m,] + ones_col %*% bias; |
| |
| # Write the raw scores only when an output path was supplied. |
| if (cmdLine_scores != " ") { |
|   write(scores, cmdLine_scores, format=cmdLine_fmt); |
| } |
| |
| # Evaluate the predictions against ground truth labels. |
| # The whole section is skipped when scoring_only=TRUE. |
| if (!cmdLine_scoring_only) { |
|   # Y defaults to the " " placeholder; without this guard a missing Y argument |
|   # reaches read() as a path and fails with an unhelpful file error. |
|   if (cmdLine_Y == " ") { |
|     stop("Stopping due to invalid argument: Label vector (Y) must be provided unless scoring_only=TRUE") |
|   } |
| |
|   Y = read(cmdLine_Y); |
| |
|   # Labels are compared element-wise with one prediction per row of X, |
|   # so Y has to be a single column with exactly N rows. |
|   if (nrow(Y) != N) { |
|     stop("Stopping due to invalid input: Label vector (Y) must have one row per row of X") |
|   } |
|   if (ncol(Y) != 1) { |
|     stop("Stopping due to invalid input: Label vector (Y) must have exactly one column") |
|   } |
| |
|   if (min(Y) < 1) { |
|     stop("Stopping due to invalid argument: Label vector (Y) must be recoded") |
|   } |
| |
|   # Predicted class = index of the highest-scoring column in each row. |
|   pred = rowIndexMax(scores); |
|   correct_percentage = sum((pred - Y) == 0) / N * 100; |
| |
|   acc_str = "Accuracy (%): " + correct_percentage |
|   print(acc_str) |
|   if (cmdLine_accuracy != " ") { |
|     write(acc_str, cmdLine_accuracy) |
|   } |
| |
|   # The labels may contain class ids larger than the number of model columns; |
|   # size the confusion matrix so that it covers both. |
|   num_classes_ground_truth = max(Y) |
|   if (num_classes < num_classes_ground_truth) { |
|     num_classes = num_classes_ground_truth |
|   } |
| |
|   # Rows index the true label, columns the predicted label. |
|   if (cmdLine_confusion != " ") { |
|     confusion_mat = table(Y, pred, num_classes, num_classes) |
|     write(confusion_mat, cmdLine_confusion, format="csv") |
|   } |
| } |