#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# This script computes class scores and label predictions for a held-out
# test set, using a multi-class SVM model learned with m-svm.dml.
#
# If ground-truth labels (Y) are supplied, the script also reports the
# prediction accuracy (%) and can optionally write a confusion matrix.
#
# Example Usage:
# hadoop jar SystemML.jar -f m-svm-predict.dml -nvargs X=data Y=labels model=model scores=scores accuracy=accuracy confusion=confusion fmt="text"
#													 

# Optional command-line arguments.
# A single space (" ") is the "not supplied" sentinel: the sections that
# read labels or write outputs compare against it and are skipped when it
# is present. fmt only controls the format used when writing the scores.
cmdLine_fmt = ifdef($fmt, "text")
cmdLine_scores = ifdef($scores, " ")
cmdLine_Y = ifdef($Y, " ")
cmdLine_accuracy = ifdef($accuracy, " ")
cmdLine_confusion = ifdef($confusion, " ")

# Load the test features and the trained model matrix.
X = read($X)
W = read($model)

# The last two rows of the model hold metadata (presumably appended by
# m-svm.dml at training time), stored in column 1:
#   row nrow(W)     -> number of feature dimensions the model expects
#   row nrow(W) - 1 -> intercept flag (1 means a bias row follows the weights)
model_rows = nrow(W)

dimensions = as.scalar(W[model_rows,1])
if(dimensions != ncol(X)) {
	stop("Stopping due to invalid input: Model dimensions do not seem to match input data dimensions")
}

intercept = as.scalar(W[model_rows-1,1])

# Drop the two metadata rows, keeping only the weights (plus bias row, if any).
W = W[1:(model_rows-2),]

# Problem sizes: N test rows, m features, and one weight column per class.
m = ncol(X)
N = nrow(X)
num_classes = ncol(W)

# Per-class bias, added to every example's scores. It stays all-zero unless
# the model has an intercept, whose bias row sits directly below the m
# feature-weight rows.
b = matrix(0, rows=1, cols=num_classes)
if (intercept == 1) {
	b = W[m+1,]
}

# Class scores: (N x m) %*% (m x num_classes), plus the bias row replicated
# onto every example through an N x 1 column of ones.
bias_expander = matrix(1, rows=N, cols=1)
scores = X %*% W[1:m,] + bias_expander %*% b

# Optionally persist the raw scores in the requested output format.
if (cmdLine_scores != " ") {
	write(scores, cmdLine_scores, format=cmdLine_fmt)
}

# Evaluate the predictions only when a ground-truth label file was supplied.
if(cmdLine_Y != " "){
	y = read(cmdLine_Y);

	# Labels must be recoded to 1..K so they are comparable with the
	# 1-based column index returned by rowIndexMax.
	if(min(y) < 1)
		stop("Stopping due to invalid argument: Label vector (Y) must be recoded")

	# Y must line up with X row-for-row and be a single column. Without these
	# checks, a mismatched row count fails with an opaque dimension error, and
	# a 1-row or multi-column Y is silently broadcast against pred, which
	# yields a wrong accuracy (possibly above 100%) and a wrong confusion matrix.
	if(nrow(y) != N)
		stop("Stopping due to invalid argument: Label vector (Y) must have one row per row of X")
	if(ncol(y) != 1)
		stop("Stopping due to invalid argument: Label vector (Y) must be a single column")

	# Predicted label = index of the highest-scoring class in each row.
	pred = rowIndexMax(scores);
	correct_percentage = sum(pred == y) / N * 100;

	acc_str = "Accuracy (%): " + correct_percentage
	print(acc_str)
	if(cmdLine_accuracy != " ")
		write(acc_str, cmdLine_accuracy)

	# Size the confusion matrix to cover every class seen in either the
	# model (ncol(W)) or the ground truth (max label).
	num_classes_ground_truth = max(y)
	if(num_classes < num_classes_ground_truth)
		num_classes = num_classes_ground_truth

	if(cmdLine_confusion != " "){
		# Entry (i, j) counts examples with true label i predicted as class j.
		confusion_mat = table(y, pred, num_classes, num_classes)
		write(confusion_mat, cmdLine_confusion, format="csv")
	}
}
