blob: b687bfa77abcb5fd932a15df392b97e94717b906 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Implements scoring with a naive Bayes model learnt using
# naive-bayes.dml
#
# hadoop jar SystemML.jar -f naive-bayes-predict.dml -nvargs X=data Y=labels prior=model_file1 conditionals=model_file2 probabilities=probabilities accuracy=accuracy confusion=confusion fmt="text"
#
# ---- Command-line arguments ------------------------------------------------
# Required: X, prior, conditionals.
# Optional: Y, accuracy, confusion, probabilities (" " means "not supplied")
#           and fmt (defaults to "text").
cmdLine_Y = ifdef($Y, " ")
cmdLine_accuracy = ifdef($accuracy, " ")
cmdLine_confusion = ifdef($confusion, " ")
cmdLine_probabilities = ifdef($probabilities, " ")
cmdLine_fmt = ifdef($fmt, "text")

# ---- Input data ------------------------------------------------------------
# D holds the examples to score, one row per example.
D = read($X)
min_feature_val = min(D)
if(min_feature_val < 0){
    stop("Stopping due to invalid argument: Multinomial naive Bayes is meant for count-based feature values, minimum value in X is negative")
}

# ---- Model -----------------------------------------------------------------
prior = read($prior)
# The last row of the stored prior is not a class prior and is stripped before
# scoring (presumably extra model metadata written at training time, e.g. the
# feature count -- TODO confirm against naive-bayes.dml).
prior = prior[1:(nrow(prior)-1),]
conditionals = read($conditionals)

# The matrix product below needs one conditional column per feature of D;
# fail early with a readable message instead of a dimension error.
if(ncol(D) != ncol(conditionals)){
    stop("Stopping due to invalid argument: Number of features in X does not match the number of columns in conditionals")
}

# ---- Scoring ---------------------------------------------------------------
numRows = nrow(D)

# Appending a column of ones to D and the prior column to the conditionals
# folds the log-prior into a single matrix product:
#   log_probs[i,k] = sum_j D[i,j] * log(conditionals[k,j]) + log(prior[k])
row_ones = matrix(1, rows=numRows, cols=1)
D_w_ones = append(D, row_ones)
model = append(conditionals, prior)
log_probs = D_w_ones %*% t(log(model))

# ---- Optional output: normalized class probabilities -----------------------
if(cmdLine_probabilities != " "){
    # Subtract each row's maximum before exponentiating so that exp() cannot
    # overflow, then normalize every row to sum to 1.
    mx = rowMaxs(log_probs)
    class_ones = matrix(1, rows=1, cols=nrow(prior))
    shifted = log_probs - mx %*% class_ones
    exp_shifted = exp(shifted)
    probs = exp_shifted / (rowSums(exp_shifted) %*% class_ones)
    write(probs, cmdLine_probabilities, format=cmdLine_fmt)
}
# ---- Optional evaluation against ground-truth labels ------------------------
# Runs only when Y was supplied. Prints the accuracy and, on request, writes
# the accuracy and the confusion matrix (rows: predicted, columns: actual).
if(cmdLine_Y != " "){
    C = read(cmdLine_Y)
    smallest_label = min(C)
    if(smallest_label < 1){
        stop("Stopping due to invalid argument: Label vector (Y) must be recoded")
    }

    # Predicted class of a row = column index of its largest log-probability.
    pred = rowIndexMax(log_probs)
    num_correct = sum(pred == C)
    acc = num_correct / numRows * 100
    print("Accuracy (%): " + acc)

    if(cmdLine_accuracy != " "){
        write(acc, cmdLine_accuracy)
    }

    if(cmdLine_confusion != " "){
        # Size the table by the larger of the model's class count and the
        # largest label seen in Y, so labels unknown to the model still fit.
        num_classes = nrow(prior)
        num_classes_ground_truth = max(C)
        if(num_classes_ground_truth > num_classes){
            num_classes = num_classes_ground_truth
        }
        confusion_mat = table(pred, C, num_classes, num_classes)
        write(confusion_mat, cmdLine_confusion, format="csv")
    }
}