blob: a959836f3e8e1505a5a124cdf5ed90e5adcc765a [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This script computes class-label predictions (scores) on a held-out test set
# using a multi-class SVM model learnt with m-svm.dml.
#
# Given ground-truth labels, the script also computes the accuracy (%) of the
# predictions and, optionally, a confusion matrix.
#
# Example Usage:
# hadoop jar SystemML.jar -f m-svm-predict.dml -nvargs X=data Y=labels model=model scores=scores accuracy=accuracy confusion=confusion fmt="text"
#
# Optional named arguments. For the label path (Y) and for every output path
# (scores, accuracy, confusion) a single space " " is the "not supplied"
# sentinel that the script later compares against. The format used when
# writing scores defaults to "text".
cmdLine_fmt = ifdef($fmt, "text")
cmdLine_scores = ifdef($scores, " ")
cmdLine_accuracy = ifdef($accuracy, " ")
cmdLine_confusion = ifdef($confusion, " ")
cmdLine_Y = ifdef($Y, " ")
# Load the test features and the trained model. The model matrix is expected
# to end with two metadata rows, both read from column 1:
#   last row           -> number of features the model was trained on
#   second-to-last row -> intercept flag (1 when a bias row is present)
X = read($X);
W = read($model);
model_rows = nrow(W)

dimensions = as.scalar(W[model_rows,1])
if (dimensions != ncol(X)) {
  stop("Stopping due to invalid input: Model dimensions do not seem to match input data dimensions")
}

intercept = as.scalar(W[model_rows-1,1])

# Drop the two metadata rows; what remains is the weight matrix.
W = W[1:(model_rows-2),]
# Score every test row against every class. The weight matrix has one column
# per class. Its first m rows are feature weights; when intercept == 1, row m+1
# is the per-class bias. Without an intercept the bias is an all-zero row.
m=ncol(X);
N = nrow(X);
num_classes = ncol(W)

if (intercept == 1) {
  b = W[m+1,]
} else {
  b = matrix(0, rows=1, cols=num_classes)
}

# Replicate the bias row for every test row (N x 1 ones times 1 x num_classes).
bias_broadcast = matrix(1, rows=N, cols=1) %*% b
scores = X %*% W[1:m,] + bias_broadcast;

if (cmdLine_scores != " ") {
  write(scores, cmdLine_scores, format=cmdLine_fmt);
}
# Optional evaluation, performed only when ground-truth labels Y are supplied.
# Labels must be recoded as class indices starting at 1. The predicted label
# of each row is the column index of its largest score (rowIndexMax).
if(cmdLine_Y != " "){
  y = read(cmdLine_Y);

  # Validate the label shape up front, using the same stop() style as the
  # model check, so a mismatched label file fails with a clear message. Without
  # this check it would fail later inside the element-wise comparison or table().
  if((nrow(y) != N) | (ncol(y) != 1))
    stop("Stopping due to invalid argument: Label vector (Y) must have exactly one column and as many rows as X")

  if(min(y) < 1)
    stop("Stopping due to invalid argument: Label vector (Y) must be recoded")

  pred = rowIndexMax(scores);
  correct_percentage = sum((pred - y) == 0) / N * 100;

  acc_str = "Accuracy (%): " + correct_percentage
  print(acc_str)
  if(cmdLine_accuracy != " ")
    write(acc_str, cmdLine_accuracy)

  # Widen the confusion matrix when the labels reference more classes than
  # the model has score columns, so table() has a row for every label.
  num_classes_ground_truth = max(y)
  if(num_classes < num_classes_ground_truth)
    num_classes = num_classes_ground_truth

  if(cmdLine_confusion != " "){
    # Rows index the true label and columns the predicted label. The matrix is
    # always written as CSV, independent of the fmt argument.
    confusion_mat = table(y, pred, num_classes, num_classes)
    write(confusion_mat, cmdLine_confusion, format="csv")
  }
}