blob: 1cbd571f3229b7be04ebc14f1d1fddb46e11bc6e [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This script computes the rating/scores for a given list of userIDs
# using 2 factor matrices L and R. We assume that all users have rated
# at least once and all items have been rated at least once.
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# userIDs Matrix --- Column vector of user-ids (n x 1)
# I Matrix --- Indicator matrix user-id x item-id; non-zero cells are excluded from scoring
# L Matrix --- The factor matrix L: user-id x feature-id
# R Matrix --- The factor matrix R: feature-id x item-id
# ---------------------------------------------------------------------------------------------
# OUTPUT:
# Y Matrix --- The output score matrix (n x item-id): row i holds the scores of user userIDs[i]; excluded cells are 0
m_alsPredict = function(Matrix[Double] userIDs, Matrix[Double] I, Matrix[Double] L, Matrix[Double] R)
  return (Matrix[Double] Y)
{
  # Computes the scores of the requested users from the factor matrices:
  # row i of Y is (row userIDs[i] of L) %*% R, with every cell that is
  # flagged (non-zero) in the matching row of I forced to 0.
  #
  #   userIDs : n x 1 column vector of 1-based row positions into L
  #   I       : exclusion indicator with one row per row of L
  #   L       : user-id x feature-id factor matrix
  #   R       : feature-id x item-id factor matrix
  #   Y       : n x ncol(R) matrix of scores
  n = nrow(userIDs);

  # The ids are used below as column positions of the selection matrix,
  # so they have to lie in [1, nrow(L)]; reject out-of-range ids up front
  # with an explicit message.
  if (min(userIDs) < 1) {
    stop("Predictions cannot be provided. Minimum user-id must be at least 1.");
  }
  X_user_max = max(userIDs);
  if (X_user_max > nrow(L)) {
    stop("Predictions cannot be provided. Maximum user-id exceeds the number of users.");
  }

  # L %*% R is only defined when the feature dimensions agree.
  if (ncol(L) != nrow(R)) {
    stop("Predictions cannot be provided. Number of columns of L does not match the number of rows of R.");
  }

  # creates the n x nrow(L) projection matrix: P[i, userIDs[i]] = 1
  P = table(seq(1, n), userIDs, n, nrow(L));

  # selects the requested users' rows from factor L and from the exclude list
  Usel = P %*% L;
  Isel = P %*% I;

  # calculates scores for the selected users; (Isel == 0) is 1 where a cell
  # is NOT excluded, so the element-wise product zeroes out excluded cells
  Y = (Isel == 0) * (Usel %*% R);
}