#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Scores a data set with a multinomial naive Bayes model learnt by
# naive-bayes.dml.
#
# Example invocation (Y, accuracy, confusion and fmt are optional;
# fmt defaults to "text"):
# hadoop jar SystemDS.jar -f naive-bayes-predict.dml -nvargs X=data Y=labels prior=model_file1 conditionals=model_file2 probabilities=probabilities accuracy=accuracy confusion=confusion fmt="text"
#

# ---- Command-line arguments ------------------------------------------------
# Optional arguments fall back to a single blank, which the rest of the script
# treats as "not supplied"; fmt falls back to "text".
cmdLine_Y = ifdef($Y, " ")
cmdLine_accuracy = ifdef($accuracy, " ")
cmdLine_confusion = ifdef($confusion, " ")
cmdLine_fmt = ifdef($fmt, "text")

# ---- Feature matrix --------------------------------------------------------
# Negative feature values are rejected up front.
D = read($X)
min_feature_val = min(D)
if (min_feature_val < 0) {
  stop("Stopping due to invalid argument: Multinomial naive Bayes is meant for count-based feature values, minimum value in X is negative")
}

# ---- Model -----------------------------------------------------------------
# The prior file carries one extra trailing row which is split off before the
# remaining rows are used as the prior. Without at least two rows the range
# 1:(nrow(prior)-1) below would be empty/invalid, so fail with a clear message.
prior = read($prior)
if (nrow(prior) < 2) {
  stop("Stopping due to invalid argument: prior must contain at least one class prior followed by one trailing row")
}

# NOTE(review): presumably the number of features recorded at training time --
# confirm against naive-bayes.dml. The value is not used anywhere below.
dimensions = as.scalar(prior[nrow(prior),1])
prior = prior[1:(nrow(prior)-1),]

conditionals = read($conditionals)

# The scoring step computes cbind(D, ones) %*% t(log(cbind(conditionals, prior))).
# Check the two shape requirements of that expression here so that a mismatched
# model/data pair is reported clearly instead of failing inside cbind or %*%.
if (nrow(conditionals) != nrow(prior)) {
  stop("Stopping due to invalid argument: conditionals has " + nrow(conditionals) + " rows but prior has " + nrow(prior) + " rows")
}
if (ncol(conditionals) + ncol(prior) != ncol(D) + 1) {
  stop("Stopping due to invalid argument: model has " + (ncol(conditionals) + ncol(prior)) + " columns (conditionals + prior) but X with its bias column has " + (ncol(D) + 1))
}

# ---- Scoring ---------------------------------------------------------------
# The probabilities output is optional, like the other outputs of this script:
# a blank default means "not supplied" and the write is skipped.
cmdLine_probabilities = ifdef($probabilities, " ")

numRows = nrow(D)

# Append a column of ones to the data and the prior column to the conditionals,
# so that a single matrix product adds the log prior to every row's score.
ones = matrix(1, rows=numRows, cols=1)
D_w_ones = cbind(D, ones)
model = cbind(conditionals, prior)

# One row per input record, one column per class.
log_probs = D_w_ones %*% t(log(model))

# Normalise each row into probabilities. The row maximum is subtracted before
# exponentiating so the largest exponent is exp(0) = 1. Column vectors are
# broadcast across columns by the matrix-vector cell-wise operators, so no
# explicit outer product with a row of ones is needed, and exp() is evaluated
# only once.
mx = rowMaxs(log_probs)
probs = exp(log_probs - mx)
probs = probs / rowSums(probs)

if (cmdLine_probabilities != " ") {
  write(probs, cmdLine_probabilities, format=cmdLine_fmt)
}


# ---- Optional evaluation against supplied labels ---------------------------
# Runs only when Y was given on the command line.
if (cmdLine_Y != " ") {
  labels = read(cmdLine_Y)

  # Labels below 1 are rejected.
  smallest_label = min(labels)
  if (smallest_label < 1) {
    stop("Stopping due to invalid argument: Label vector (Y) must be recoded")
  }

  # Predicted class = index of the highest-scoring column in each row.
  predicted = rowIndexMax(log_probs)
  num_correct = sum(predicted == labels)
  accuracy_pct = num_correct / numRows * 100

  print("Accuracy (%): " + accuracy_pct)
  if (cmdLine_accuracy != " ") {
    write(accuracy_pct, cmdLine_accuracy)
  }

  if (cmdLine_confusion != " ") {
    # Size the table by whichever is larger: the number of classes in the
    # model or the largest label value found in Y.
    num_classes = nrow(prior)
    largest_label = max(labels)
    if (num_classes < largest_label) {
      num_classes = largest_label
    }

    confusion_mat = table(predicted, labels, num_classes, num_classes)
    write(confusion_mat, cmdLine_confusion, format="csv")
  }
}