blob: 67b9870c7f819a4be4ea0255838ed2aaefa47957 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* L2 loss function.
*/
forward = function(matrix[double] pred, matrix[double] y)
    return (double loss) {
  /*
   * Forward pass of the L2 (squared error) loss.
   *
   * Each of the N rows of `pred` and `y` is one example with M
   * predicted dimensions. The per-example loss is half the squared
   * Euclidean distance between the prediction row and the target
   * row, and the returned value is the mean over all examples:
   *
   * ```
   * L_i = (1/2) norm(pred_i - y_i)^2
   * L = (1/N) sum(L_i) for i=1 to N
   * ```
   *
   * Here `pred_i` and `y_i` denote the i-th rows of `pred` and `y`,
   * `L_i` is the loss of example `i`, and `L` is the total loss.
   *
   * This can be interpreted as the negative log-likelihood assuming
   * a Gaussian distribution.
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, M).
   *  - y: Targets, of shape (N, M).
   *
   * Outputs:
   *  - loss: Average loss.
   */
  num_examples = nrow(y)

  # Element-wise residual between predictions and targets.
  residual = pred - y

  # Half the squared norm of each row, i.e. one loss value per example.
  example_losses = 0.5 * rowSums(residual^2)

  # Mean over the examples.
  loss = sum(example_losses) / num_examples
}
backward = function(matrix[double] pred, matrix[double] y)
    return (matrix[double] dpred) {
  /*
   * Backward pass of the L2 (squared error) loss.
   *
   * Each of the N rows of `pred` and `y` is one example with M
   * predicted dimensions. The gradient of the averaged loss with
   * respect to the predictions is the element-wise residual scaled
   * by the number of examples:
   *
   * ```
   * dpred = (pred - y) / N
   * ```
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, M).
   *  - y: Targets, of shape (N, M).
   *
   * Outputs:
   *  - dpred: Gradient wrt `pred`, of shape (N, M).
   */
  # Element-wise residual between predictions and targets.
  residual = pred - y

  # Dividing by the row count matches the mean taken over examples.
  dpred = residual / nrow(y)
}