blob: 5bf4390cd614eae23e838fb48f38e8e8c8387772 [file] [log] [blame]
#' Helper function to create a customized metric
#'
#' Wraps a per-batch evaluation function into an \code{mx.metric} object.
#' The metric state is a length-2 numeric vector: the number of batches
#' seen so far and the running sum of the per-batch values. \code{get}
#' reports their ratio, i.e. the mean of the per-batch values (this is
#' \code{NaN} if \code{get} is called on a freshly initialised state).
#'
#' @param name Name reported in the list returned by \code{get}.
#' @param feval Function of \code{(label, pred)} evaluated on one batch;
#'   both arguments are passed through \code{as.array} first. It is
#'   expected to return a single numeric value.
#' @return A list of class \code{"mx.metric"} with the closures
#'   \code{init()}, \code{update(label, pred, state)} and \code{get(state)}.
#'
#' @export
mx.metric.custom <- function(name, feval) {
  metric <- list(
    # Fresh state: zero batches, zero accumulated value.
    init = function() c(0, 0),
    update = function(label, pred, state) {
      batch_value <- feval(as.array(label), as.array(pred))
      c(state[[1]] + 1, state[[2]] + batch_value)
    },
    get = function(state) {
      list(name = name, value = state[[2]] / state[[1]])
    }
  )
  structure(metric, class = "mx.metric")
}
#' Accuracy metric for classification
#'
#' Fraction of examples whose highest-scoring class equals the label.
#' The prediction is read with one row per class and one column per
#' example (it is transposed before \code{max.col}), and labels are treated
#' as 0-based class indices (they are shifted by 1 before comparison).
#' Ties between scores are resolved in favour of the first class.
#'
#' @export
mx.metric.accuracy <- mx.metric.custom("accuracy", function(label, pred) {
  # Full argument name: `tie =` only worked through partial matching.
  ypred <- max.col(t(as.array(pred)), ties.method = "first")
  sum((as.array(label) + 1) == ypred) / length(label)
})
#' Helper function for top-k accuracy
#'
#' Tests whether at least one element of \code{vect} matches an element
#' of \code{num}.
#'
#' @param vect Vector of candidate values (e.g. ranked class indices).
#' @param num Value(s) to look for.
#' @return A single logical.
#' @keywords internal
is.num.in.vect <- function(vect, num) {
  any(vect %in% num)
}
#' Top-k accuracy metric for classification
#'
#' Fraction of examples whose true class is among the \code{top_k}
#' highest-scoring classes for that same example. The prediction is read
#' with one row per class and one column per example, and labels are
#' treated as 0-based class indices (they are shifted by 1 before
#' comparison). If there are fewer classes than \code{top_k}, every class
#' is considered.
#'
#' Note: \code{mx.metric.custom} calls the evaluation function with only
#' \code{(label, pred)}, so \code{top_k} stays at its default of 5.
#'
#' @export
mx.metric.top_k_accuracy <- mx.metric.custom(
  "top_k_accuracy",
  function(label, pred, top_k = 5) {
    pred <- as.array(pred)
    label <- as.array(label)
    # Never ask for more ranked classes than exist; indexing past the end
    # of order() would only yield NAs.
    k <- min(top_k, nrow(pred))
    hits <- vapply(seq_len(ncol(pred)), function(j) {
      top_classes <- order(pred[, j], decreasing = TRUE)[seq_len(k)]
      # Compare each example with its OWN label, not the whole batch's
      # labels (doing the latter inflates the score).
      is.num.in.vect(top_classes, label[[j]] + 1)
    }, logical(1))
    sum(hits) / length(label)
  }
)
#' MSE (Mean Squared Error) metric for regression
#'
#' Mean of the squared differences between labels and predictions.
#' Length-one dimensions are dropped from both inputs first. This lets a
#' length-n label be compared with a 1 x n (or n x 1) prediction. Inputs
#' whose remaining shapes differ still raise a "non-conformable" error.
#'
#' @export
mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
  # R refuses arithmetic between arrays whose `dim` attributes differ, so
  # strip singleton extents before subtracting.
  mean((drop(label) - drop(pred))^2)
})
#' RMSE (Root Mean Squared Error) metric for regression
#'
#' Square root of the mean squared difference between labels and
#' predictions within one batch. Length-one dimensions are dropped from
#' both inputs first, so a length-n label can be compared with a 1 x n
#' (or n x 1) prediction.
#'
#' Because \code{mx.metric.custom} averages the per-batch values, the
#' reported figure is the mean of per-batch RMSEs. This is not the RMSE
#' over all examples.
#'
#' @export
mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
  # Strip singleton extents so arrays like (n) and (1, n) are conformable.
  sqrt(mean((drop(label) - drop(pred))^2))
})
#' MAE (Mean Absolute Error) metric for regression
#'
#' Mean of the absolute differences between labels and predictions.
#' Length-one dimensions are dropped from both inputs first, so a length-n
#' label can be compared with a 1 x n (or n x 1) prediction.
#'
#' @export
mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
  # Strip singleton extents so arrays like (n) and (1, n) are conformable.
  mean(abs(drop(label) - drop(pred)))
})
#' RMSLE (Root Mean Squared Logarithmic Error) metric for regression
#'
#' Square root of the mean squared difference between \code{log(1 + pred)}
#' and \code{log(1 + label)} within one batch. Values must be greater than
#' -1, otherwise the logarithm is NaN. Length-one dimensions are dropped
#' from both inputs first, so a length-n label can be compared with a
#' 1 x n (or n x 1) prediction.
#'
#' @export
mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
  # log1p(x) is the numerically accurate form of log(x + 1) for small x.
  sqrt(mean((log1p(drop(pred)) - log1p(drop(label)))^2))
})