blob: 18592147efb2a2cd58ab6acb648b50dff7de2271 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This script computes the ratings/scores for a given list of userIDs
# using 2 factor matrices L and R. We assume that all users have rated
# at least once and all items have been rated at least once.
#
# INPUT:
# --------------------------------------------------------------------
# userIDs Column vector of user-ids (n x 1)
# I Indicator matrix user-id x item-id; non-zero entries are excluded from scoring
# L The factor matrix L: user-id x feature-id
# R The factor matrix R: feature-id x item-id
# --------------------------------------------------------------------
#
# OUTPUT:
# ----------------------------------------------------------------------------
# Y The output score matrix: one row per entry of userIDs, one column per item-id (excluded entries are 0)
# ----------------------------------------------------------------------------
m_alsPredict = function(Matrix[Double] userIDs, Matrix[Double] I, Matrix[Double] L, Matrix[Double] R)
  return (Matrix[Double] Y)
{
  # Computes the score of every item for each requested user as the product of
  # the user's row of L with R, then zeroes out every cell whose corresponding
  # entry in the exclude indicator I is non-zero.
  #
  # userIDs : n x 1 column vector of user-ids (used as 1-based row positions of L)
  # I       : exclude indicator, multiplied element-wise against the scores, so it
  #           must have nrow(L) rows and ncol(R) columns
  # L, R    : factor matrices; ncol(L) must equal nrow(R)
  # Y       : n x ncol(R) matrix of scores, row i belonging to userIDs[i]
  n = nrow(userIDs);

  # user-ids become column positions of the projection matrix below, so they
  # have to lie within 1..nrow(L)
  if (min(userIDs) < 1) {
    stop("Predictions cannot be provided. Minimum user-id must be at least 1.");
  }
  if (max(userIDs) > nrow(L)) {
    stop("Predictions cannot be provided. Maximum user-id exceeds the number of users.");
  }
  # inner dimensions of L %*% R must agree
  if (ncol(L) != nrow(R)) {
    stop("Predictions cannot be provided. Number of columns of L doesn't match the number of rows of R.");
  }

  # creates projection matrix to select users: P[i, userIDs[i]] = 1
  P = table(seq(1, n), userIDs, n, nrow(L));

  # selects the requested users' rows from factor L and from the exclude list
  Usel = P %*% L;
  Isel = P %*% I;

  # calculates scores for selected users; (Isel == 0) masks out excluded cells
  Y = (Isel == 0) * (Usel %*% R);
}