blob: 2de05c7a09299f461903c10e01d96590436dc488 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Built-in function that computes the simple or class-weighted accuracy
# (as a percentage) of predicted labels against the ground-truth labels.
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# y Matrix[Double] --- Ground truth (actual labels)
# yhat Matrix[Double] --- Predictions (predicted labels)
# isWeighted Boolean FALSE flag for weighted or non-weighted accuracy calculation
# ---------------------------------------------------------------------------------------------
#Output(s)
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# accuracy Double --- accuracy of the predicted labels, as a percentage (0-100)
# Computes the accuracy (as a percentage) of yhat against y.
#   isWeighted = FALSE (default): simple accuracy, i.e. the share of rows
#                                 where the prediction equals the label.
#   isWeighted = TRUE:            weighted accuracy, i.e. the per-class hit
#                                 rate averaged over all classes, so every
#                                 class contributes equally regardless of
#                                 how many rows it has.
m_getAccuracy = function(Matrix[Double] y, Matrix[Double] yhat, Boolean isWeighted = FALSE)
return (Double accuracy)
{
  if(!isWeighted)
  {
    # simple accuracy: fraction of element-wise matches, scaled to percent
    numCorrect = sum(y == yhat)
    accuracy = (numCorrect / nrow(y)) * 100
  }
  else
  {
    # per-class row counts of the ground truth; row i holds the count of label i
    # NOTE(review): assumes labels are positive integers 1..k in a single
    # column, as required by table() and the seq() comparison below - confirm
    classes = table(y, 1)
    numClasses = nrow(classes)
    # n x k matrix whose every row is (1, 2, ..., k)
    resp = matrix(0, nrow(y), numClasses) + t(seq(1, numClasses))
    # one-hot indicators of the true and the predicted class per row
    respY = resp == y
    respYhat = resp == yhat
    # 1 only where the prediction hits the true class
    pred = respY * respYhat
    # classes absent from y have count 0; use 1 to avoid division by zero
    # (their hit count is 0 as well, so they contribute 0 to the mean)
    classes = replace(target = classes, pattern = 0, replacement = 1)
    # mean of per-class hit rates, scaled to percent
    accuracy = mean(colSums(pred) / t(classes)) * 100
  }
}