| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
#
# This script computes an approximate factorization of a low-rank matrix X into two matrices U and V
# using the Alternating-Least-Squares (ALS) algorithm with conjugate gradient.
# Matrices U and V are computed by minimizing a loss function (with regularization).
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME      TYPE            DEFAULT   MEANING
# ---------------------------------------------------------------------------------------------
# X         Matrix[Double]  ---       Input matrix X to be factorized; only its nonzero cells
#                                     contribute to the loss (W = (X != 0))
# rank      Int             10        Rank of the factorization
# reg       String          "L2"      Regularization (any other value aborts with an error):
#                                     "L2" = L2 regularization;
#                                        f (U, V) = 0.5 * sum (W * (U %*% V - X) ^ 2)
#                                                 + 0.5 * lambda * (sum (U ^ 2) + sum (V ^ 2))
#                                     "wL2" = weighted L2 regularization
#                                        f (U, V) = 0.5 * sum (W * (U %*% V - X) ^ 2)
#                                                 + 0.5 * lambda * (sum (U ^ 2 * row_nonzeros)
#                                                 + sum (V ^ 2 * col_nonzeros))
# lambda    Double          0.000001  Regularization parameter, no regularization if 0.0
# maxi      Int             50        Maximum number of iterations
# check     Boolean         TRUE      Check for convergence after every iteration, i.e., updating U and V once
# thr       Double          0.0001    Assuming check is set to TRUE, the algorithm stops and convergence is declared
#                                     if the relative decrease in loss between two consecutive iterations,
#                                     (previous loss - current loss) / previous loss, is non-negative and
#                                     falls below this threshold; if check is FALSE thr is ignored
# verbose   Boolean         TRUE      Print the regularization banner, the train loss per iteration
#                                     and the final convergence status
# ---------------------------------------------------------------------------------------------
# OUTPUT:
# 1- An m x r matrix U, where r is the factorization rank
# 2- An r x n matrix V
| |
| |
# Approximate factorization X ~ U %*% V via alternating conjugate-gradient solves.
#
# Only the nonzero cells of X enter the squared loss. U and V are updated in
# turn (U first); one outer iteration consists of one U update followed by one
# V update, each performed with at most `rank` conjugate-gradient steps.
#
#   X        input matrix (m x n)
#   rank     number of latent factors
#   reg      "L2" (uniform penalty) or "wL2" (penalty weighted by the number of
#            nonzeros per row / column of X); anything else stops with an error
#   lambda   regularization strength
#   maxi     maximum number of outer iterations
#   check    evaluate the loss after every outer iteration and stop early once
#            the relative decrease is in [0, thr)
#   thr      threshold on the relative loss decrease (only used if check)
#   verbose  print progress messages
#
# Returns U (m x rank) and V (rank x n).
m_alsCG = function(Matrix[Double] X, Integer rank = 10, String reg = "L2", Double lambda = 0.000001, Integer maxi = 50, Boolean check = TRUE, Double thr = 0.0001, Boolean verbose = TRUE)
  return (Matrix[Double] U, Matrix[Double] V)
{
  num_rows = nrow(X);
  num_cols = ncol(X);

  # random starting factors; V is held as n x rank and transposed at the very end
  U = rand(rows = num_rows, cols = rank, min = -0.5, max = 0.5);
  V = rand(rows = num_cols, cols = rank, min = -0.5, max = 0.5);

  # 0/1 indicator of the nonzero cells of X
  W = (X != 0);

  # per-row / per-column penalty weights (column vectors), set by the chosen regularizer
  row_weights = matrix(0, rows = 1, cols = 1);
  col_weights = matrix(0, rows = 1, cols = 1);
  if( reg == "L2" ) {
    # uniform penalty: 0.5 * lambda * (sum (U ^ 2) + sum (V ^ 2))
    if( verbose ) {
      print ("BEGIN ALS-CG SCRIPT WITH NONZERO SQUARED LOSS + L2 WITH LAMBDA - " + lambda);
    }
    row_weights = matrix(1, rows = nrow(W), cols = 1);
    col_weights = matrix(1, rows = ncol(W), cols = 1);
  }
  else if( reg == "wL2" ) {
    # penalty scaled by the nonzero count of each row of X (for U) and each column of X (for V)
    if( verbose ) {
      print ("BEGIN ALS-CG SCRIPT WITH NONZERO SQUARED LOSS + WEIGHTED L2 WITH LAMBDA - " + lambda);
    }
    row_weights = rowSums(W);
    col_weights = t(colSums(W));
  }
  else {
    stop ("wrong regularization! " + reg);
  }

  # loss of the previous outer iteration; only maintained when check is TRUE
  loss_prev = 0.0;
  if( check ) {
    fit_term = 0.5 * sum( (X != 0) * (U %*% t(V) - X) ^ 2);
    penalty_term = 0.5 * lambda * (sum (U ^ 2 * row_weights) + sum (V ^ 2 * col_weights));
    loss_prev = fit_term + penalty_term;
    if( verbose ) {
      print ("----- Initial train loss: " + loss_prev + " -----");
    }
  }

  update_U = TRUE;      # which factor the next half step optimizes; starts with U
  cg_max_steps = rank;  # cap on conjugate-gradient steps per half step
  cg_rel_tol = 0.000000001;

  half_step = 0;        # counts single-factor updates; two of them form one outer iteration
  converged = FALSE;
  while( as.integer(half_step/2) < maxi & ! converged ) {
    half_step = half_step + 1;

    # gradient of the regularized loss w.r.t. the factor being updated
    if( update_U ) {
      G = ((X != 0) * (U %*% t(V) - X)) %*% V + lambda * U * row_weights;
    }
    else {
      G = t(t(U) %*% ((X != 0) * (U %*% t(V) - X))) + lambda * V * col_weights;
    }

    resid = -G;
    search_dir = resid;
    grad_sq = sum (G ^ 2);
    resid_sq = grad_sq;
    resid_target = cg_rel_tol * grad_sq;

    cg_step = 1;
    while( resid_sq > resid_target & cg_step <= cg_max_steps ) {
      # Hessian-vector product along the current search direction; it reads
      # only the factor that is held fixed during this half step
      if( update_U ) {
        H_dir = (W * (search_dir %*% t(V))) %*% V + lambda * search_dir * row_weights;
      }
      else {
        H_dir = t(t(U) %*% (W * (U %*% t(search_dir)))) + lambda * search_dir * col_weights;
      }

      step_len = resid_sq / sum (search_dir * H_dir);
      if( update_U ) {
        U = U + step_len * search_dir;
      }
      else {
        V = V + step_len * search_dir;
      }

      resid = resid - step_len * H_dir;
      resid_sq_before = resid_sq;
      resid_sq = sum (resid ^ 2);
      search_dir = resid + (resid_sq / resid_sq_before) * search_dir;
      cg_step = cg_step + 1;
    }

    update_U = ! update_U;

    # the loss is evaluated only after both U and V have been updated once
    if( check & (half_step%%2 == 0) ) {
      fit_term = 0.5 * sum( (X != 0) * (U %*% t(V) - X) ^ 2);
      penalty_term = 0.5 * lambda * (sum (U ^ 2 * row_weights) + sum (V ^ 2 * col_weights));
      loss_now = fit_term + penalty_term;

      # relative decrease with respect to the previous outer iteration
      loss_dec = (loss_prev - loss_now) / loss_prev;
      if( verbose ) {
        print ("Train loss at iteration (" + as.integer(half_step/2) + "): " + loss_now + " loss-dec " + loss_dec);
      }
      if( (loss_dec >= 0 & loss_dec < thr) | loss_prev == 0 ) {
        if( verbose ) {
          print ("----- ALS-CG converged after " + as.integer(half_step/2) + " iterations!");
        }
        converged = TRUE;
      }
      loss_prev = loss_now;
    }
  }

  if( verbose ) {
    if( check ) {
      print ("----- Final train loss: " + loss_prev + " -----");
    }
    if( !converged ) {
      print ("Max iteration achieved but not converged!");
    }
  }

  # hand V back as rank x n
  V = t(V);
}