blob: 6738d36e98f1872285cd45cbb993aff2e3749bfe [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This builtin function computes the simple or the class-weighted accuracy (in percent) of the given predictions
#
# INPUT:
# --------------------------------------------------------------------------------------
# y Ground truth (Actual Labels)
# yhat Predictions (Predicted labels)
# isWeighted Flag for weighted or non-weighted accuracy calculation
# --------------------------------------------------------------------------------------
#
# OUTPUT:
# --------------------------------------------------------------------------------------------
# accuracy Accuracy of the predicted labels, in percent (0 to 100)
# --------------------------------------------------------------------------------------------
m_getAccuracy = function(Matrix[Double] y, Matrix[Double] yhat, Boolean isWeighted = FALSE)
  return (Double accuracy)
{
  # Computes the accuracy of the predictions yhat against the ground truth y,
  # expressed as a percentage.
  #   isWeighted = FALSE: share of rows where the prediction equals the label.
  #   isWeighted = TRUE:  per-class share of correctly predicted rows, averaged
  #                       with equal weight over the classes 1..max(y).
  # NOTE(review): assumes y and yhat are single-column matrices with the same
  # number of rows and that labels are positive integers - confirm with callers.
  if(!isWeighted)
  {
    # Count of rows whose predicted label matches the actual label.
    numCorrect = sum(y == yhat)
    accuracy = (numCorrect / nrow(y)) * 100
  }
  else
  {
    # Per-class label counts: row c holds the number of occurrences of label c in y.
    classes = table(y, 1, max(y), 1)
    numClasses = nrow(classes)

    # Every row of resp is the sequence 1..numClasses, so comparing it with a
    # label column yields a 0/1 indicator matrix (one column per class).
    resp = matrix(0, nrow(y), numClasses) + t(seq(1, numClasses))
    respY = resp == y
    respYhat = resp == yhat

    # pred[i, c] is 1 only if row i is labelled c and is also predicted as c.
    pred = respY * respYhat

    # Labels in 1..max(y) that never occur in y have a count of 0; replace it by 1
    # to avoid a division by zero below.
    classes = replace(target = classes, pattern = 0, replacement = 1)

    # NOTE(review): classes absent from y contribute a per-class accuracy of 0
    # and therefore lower the mean - confirm this is the intended behavior.
    accuracy = mean(colSums(pred) / t(classes)) * 100
  }
}