blob: 026b88b27cbcc7ab2017fd160824dae27273a566 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Multi-class SVM prediction.
#
# Scores every row of the feature matrix X against a trained model W and
# assigns each row the 1-based index of its highest-scoring class.
#
# Command-line parameters:
#   $X          feature matrix to score, one example per row (required)
#   $W          model matrix (required), laid out as:
#                 rows 1..m        per-class feature weights (one column per class)
#                 row  m+1         per-class intercepts (used only if the flag is 1)
#                 second-last row  intercept flag in column 1
#                 last row         number of feature dimensions in column 1
#   $scores     optional output location for the predicted class labels
#   $fmt        output format for the predictions (default "text")
#   $Y          optional ground-truth labels; must be recoded to 1..#classes
#   $accuracy   optional output location for the accuracy string (needs $Y)
#   $confusion  optional output location for the confusion matrix, csv (needs $Y)

# A single blank string marks "argument not supplied".
cmdLine_Y = ifdef($Y, " ")
cmdLine_confusion = ifdef($confusion, " ")
cmdLine_accuracy = ifdef($accuracy, " ")
cmdLine_scores = ifdef($scores, " ")
cmdLine_fmt = ifdef($fmt, "text")

X = read($X);
W = read($W);

# X is consumed as a regular (examples x features) matrix, so the number of
# examples is its row count.
N = nrow(X);
m = ncol(X);

# The last two rows of the model carry metadata rather than weights.
dimensions = as.scalar(W[nrow(W),1])
if (dimensions != m) {
  stop("Stopping due to invalid input: Model dimensions do not seem to match input data dimensions")
}
intercept = as.scalar(W[nrow(W)-1,1])
W = W[1:(nrow(W)-2),]

num_classes = ncol(W)

# Without an intercept the bias row is all zeros.
b = matrix(0, rows=1, cols=num_classes)
if (intercept == 1) {
  b = W[m+1,]
}

# ones %*% b replicates the bias row once per example.
ones = matrix(1, rows=N, cols=1)
scores = X %*% W[1:m,] + ones %*% b;
predicted_y = rowIndexMax(scores);

# Only write predictions when an output location was actually supplied.
if (cmdLine_scores != " ") {
  write(predicted_y, cmdLine_scores, format=cmdLine_fmt);
}

if (cmdLine_Y != " ") {
  y = read(cmdLine_Y);
  if (min(y) < 1) {
    stop("Stopping due to invalid argument: Label vector (Y) must be recoded")
  }

  correct_percentage = sum(ppred(predicted_y - y, 0, "==")) / N * 100;
  acc_str = "Accuracy (%): " + correct_percentage
  print(acc_str)
  if (cmdLine_accuracy != " ") {
    write(acc_str, cmdLine_accuracy)
  }

  # The ground truth may contain classes the model never saw; size the
  # confusion matrix to cover both.
  num_classes_ground_truth = max(y)
  if (num_classes < num_classes_ground_truth) {
    num_classes = num_classes_ground_truth
  }

  if (cmdLine_confusion != " ") {
    # Rows index the predicted class, columns the true class.
    confusion_mat = table(predicted_y, y, num_classes, num_classes)
    write(confusion_mat, cmdLine_confusion, format="csv")
  }
}