blob: b74566d2dcb010f1ff1e80067760f9b4e6cf8f0c [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* L1 loss function.
*/
forward = function(matrix[double] pred, matrix[double] y)
    return (double loss) {
  /*
   * Forward pass of the L1 (absolute error) loss.
   *
   * Given N examples with M predicted dimensions each, the loss is
   * the per-example sum of absolute errors, averaged over examples:
   *
   * ```
   * L_i = sum_j(abs((pred_i)_j - (y_i)_j))  for j=1 to M
   * L   = (1/N) sum(L_i)                    for i=1 to N
   * ```
   *
   * Here `L` is the total loss, `L_i` is the loss for example `i`,
   * `y_i` and `pred_i` are the M-dimensional target and prediction
   * for example `i`, and `N` is the number of examples.
   *
   * This can be interpreted as the negative log-likelihood assuming
   * a Laplace distribution.
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, M).
   *  - y: Targets, of shape (N, M).
   *
   * Outputs:
   *  - loss: Average loss.
   */
  # Number of examples is taken from the targets.
  num_examples = nrow(y)

  # Element-wise absolute error, then one summed loss per example (row).
  abs_err = abs(pred - y)
  per_example = rowSums(abs_err)

  # Average the per-example losses.
  loss = sum(per_example) / num_examples
}
backward = function(matrix[double] pred, matrix[double] y)
    return (matrix[double] dpred) {
  /*
   * Backward pass of the L1 (absolute error) loss.
   *
   * Given N examples with M predicted dimensions each, returns the
   * gradient of the averaged loss with respect to the predictions:
   * the element-wise sign of the residual `pred - y`, scaled by 1/N.
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, M).
   *  - y: Targets, of shape (N, M).
   *
   * Outputs:
   *  - dpred: Gradient wrt `pred`, of shape (N, M).
   */
  # Number of examples is taken from the targets.
  num_examples = nrow(y)

  # Derivative of abs(r) wrt r is sign(r); the 1/N factor comes from
  # averaging the loss over examples.
  residual = pred - y
  dpred = sign(residual) / num_examples
}