blob: 15914f7bfea7629566f04437c2b35c10ec99cd30 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* Log loss function.
*/
forward = function(matrix[double] pred, matrix[double] y)
    return (double loss) {
  /*
   * Computes the forward pass for a log loss function.
   *
   * ```
   * L_i = -y_i*log(pred_i) - (1-y_i)*log(1-pred_i)
   * L = (1/N) sum(L_i) for i=1 to N
   * ```
   *
   * In these equations, `L` is the total loss, `L_i` is the loss for
   * example `i`, `y_i` is the binary target, `pred_i` is the
   * probability of the true class (i.e. `y=1`), and `N` is the number
   * of examples.
   *
   * This can be interpreted as the negative log-likelihood assuming
   * a Bernoulli distribution.
   *
   * For numerical stability, predictions are clipped to the interval
   * [eps, 1-eps] before the logarithms are taken, so that saturated
   * predictions of exactly 0 or 1 yield a large but finite loss
   * rather than Inf/NaN.  Predictions already inside that interval
   * are not modified.
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, 1).
   *      Predictions should be probabilities of the true
   *      class (i.e. probability of `y=1`).
   *  - y: Targets, of shape (N, 1).
   *      Targets should be binary in the set {0, 1}.
   *
   * Outputs:
   *  - loss: Average loss.
   */
  N = nrow(y)

  # Keep probabilities strictly inside (0, 1): log(0) = -Inf, and the
  # product 0 * -Inf would otherwise poison the sum with NaN.
  eps = 1e-15
  p = max(min(pred, 1 - eps), eps)

  losses = -y*log(p) - (1-y)*log(1-p)
  loss = sum(losses) / N
}
backward = function(matrix[double] pred, matrix[double] y)
    return (matrix[double] dpred) {
  /*
   * Computes the backward pass for a log loss function.
   *
   * The gradient of the average loss wrt each prediction is
   *
   * ```
   * dpred_i = (1/N) * (pred_i - y_i) / (pred_i * (1 - pred_i))
   * ```
   *
   * For numerical stability, predictions are clipped to the interval
   * [eps, 1-eps] before the gradient is formed, so that saturated
   * predictions of exactly 0 or 1 do not cause a division by zero
   * (Inf/NaN).  Predictions already inside that interval are not
   * modified.
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, 1).
   *      Predictions should be probabilities of the true
   *      class (i.e. probability of `y=1`).
   *  - y: Targets, of shape (N, 1).
   *      Targets should be binary in the set {0, 1}.
   *
   * Outputs:
   *  - dpred: Gradient wrt `pred`, of shape (N, 1).
   */
  N = nrow(y)

  # Keep probabilities strictly inside (0, 1) so that the denominator
  # p*(1-p) can never be exactly zero.
  eps = 1e-15
  p = max(min(pred, 1 - eps), eps)

  dpred = (1/N) * (p-y) / (p*(1-p))
}