| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| /* |
| * Factorization Machines. |
| */ |
| |
| forward = function(matrix[double] X, matrix[double] w0, matrix[double] W, matrix[double] V) |
| return (matrix[double] out) { |
|   /* |
|    * Forward pass of a second-order factorization machine. |
|    * |
|    * The prediction is the sum of three parts: the global bias w0, |
|    * the linear term X %*% W, and the pairwise interaction term |
|    * 0.5 * rowSums((X %*% V)^2 - (X^2 %*% V^2)). |
|    * |
|    * Reference: |
|    * - Factorization Machines, Steffen Rendle. |
|    * |
|    * Inputs: |
|    * - X : n examples with d features, of shape (n, d). |
|    * - w0: the global bias, of shape (1,). |
|    * - W : the strength of each feature, of shape (d, 1). |
|    * - V : factorized interaction terms, of shape (d, k). |
|    * |
|    * Outputs: |
|    * - out : target vector, of shape (n, 1). |
|    */ |
|   # Linear contribution of every feature. |
|   linear_term = X %*% W # shape (n, 1) |
| |
|   # Pairwise interactions: squared projection minus projection of squares. |
|   squared_of_sum = (X %*% V)^2 # shape (n, k) |
|   sum_of_squares = X^2 %*% V^2 # shape (n, k) |
|   pairwise_term = 0.5 * rowSums(squared_of_sum - sum_of_squares) # shape (n, 1) |
| |
|   # The bias w0 is added last and is applied to every row. |
|   out = linear_term + pairwise_term + w0 # shape (n, 1) |
| } |
| |
| backward = function(matrix[double] dout, matrix[double] X, matrix[double] w0, matrix[double] W, |
| matrix[double] V) |
| return (matrix[double] dw0, matrix[double] dW, matrix[double] dV) { |
|   /* |
|    * Backward pass of the factorization machine: takes the upstream |
|    * gradient w.r.t. the output target vector and returns the |
|    * gradients of the loss w.r.t. the parameters w0, W and V. |
|    * |
|    * With out = w0 + X %*% W + 0.5 * rowSums((X %*% V)^2 - (X^2 %*% V^2)) |
|    * the partial derivatives of out(i) are |
|    *   d out(i) / d w0     = 1 |
|    *   d out(i) / d W(j)   = X(i,j) |
|    *   d out(i) / d V(j,f) = X(i,j) * (X %*% V)(i,f) - X(i,j)^2 * V(j,f) |
|    * and each one is weighted by dout(i) and summed over the examples i. |
|    * |
|    * Inputs: |
|    * - dout : the gradient of the loss function w.r.t y, of |
|    *          shape (n, 1). |
|    * - X, w0, W, V are as mentioned in the above forward function. |
|    *   (w0 and W are accepted for interface symmetry; their values |
|    *   are not needed to compute the gradients.) |
|    * |
|    * Outputs: |
|    * - dw0: the gradient of loss function w.r.t w0, of shape (1,). |
|    * - dW : the gradient of loss function w.r.t W, of shape (d, 1). |
|    * - dV : the gradient of loss function w.r.t V, of shape (d, k). |
|    * |
|    * Note: no gradient w.r.t. X is computed or returned. |
|    */ |
| |
|   # 1. gradient of loss function w.r.t. w0: every example contributes |
|   #    its upstream gradient once. |
|   dw0 = colSums(dout) # shape (1, 1) |
| |
|   # 2. gradient of loss function w.r.t. W |
|   dW = t(X) %*% dout # shape (d, 1) |
| |
|   # 3. gradient of loss function w.r.t. V, fully vectorized. |
|   #    The earlier loop-based version indexed xv[,k] instead of xv[,j], |
|   #    which produced wrong columns for k >= 3, and its loops started |
|   #    at index 2, which is not safe when d = 1 or k = 1. |
| |
|   ## First term: sum_i dout(i) * X(i,j) * (X %*% V)(i,f) |
|   xv = X %*% V # shape (n, k) |
|   weighted_xv = xv * dout # shape (n, k), dout broadcast over columns |
|   g_V1 = t(X) %*% weighted_xv # shape (d, k) |
| |
|   ## Second term: V(j,f) * sum_i dout(i) * X(i,j)^2 |
|   Xt = t(X^2) %*% dout # shape (d, 1) |
|   g_V2 = V * Xt # shape (d, k), Xt broadcast over columns |
| |
|   dV = g_V1 - g_V2 # shape (d, k) |
| } |
| |
| init = function(int d, int k) |
| return (matrix[double] w0, matrix[double] W, matrix[double] V) { |
|   /* |
|    * Creates the initial parameters of the factorization machine. |
|    * |
|    * The bias and the linear weights start at zero; the interaction |
|    * factors are drawn with rand() from a uniform distribution on |
|    * [0, 1] with a requested sparsity of 0.08. |
|    * |
|    * Inputs: |
|    * - d: the number of features, is an integer. |
|    * - k: the factorization dimensionality, is an integer. |
|    * |
|    * Outputs: |
|    * - w0: the global bias, of shape (1,). |
|    * - W : the strength of each feature, of shape (d, 1). |
|    * - V : factorized interaction terms, of shape (d, k). |
|    */ |
|   # Interaction factors: sparse, uniformly distributed values. |
|   V = rand(rows=d, cols=k, min=0.0, max=1.0, pdf="uniform", sparsity=0.08) |
| |
|   # Linear weights and global bias: all zeros. |
|   W = matrix(0, rows=d, cols=1) |
|   w0 = matrix(0, rows=1, cols=1) |
| } |
| |