blob: f3c881ea9f6f16475a2bbc27b565daf5ccf736a3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.recommendation
// This must be the same package as Spark's MatrixFactorizationModel because
// MatrixFactorizationModel's constructor is private and we are using
// its constructor in order to save and load the model
import com.github.fommil.netlib.BLAS.{getInstance => blas} // ADDED
import org.apache.predictionio.examples.recommendation.ALSAlgorithmParams
import org.apache.predictionio.controller.PersistentModel
import org.apache.predictionio.controller.PersistentModelLoader
import org.apache.predictionio.data.storage.BiMap
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
class ALSModel(
override val rank: Int,
override val userFeatures: RDD[(Int, Array[Double])],
override val productFeatures: RDD[(Int, Array[Double])],
val userStringIntMap: BiMap[String, Int],
val itemStringIntMap: BiMap[String, Int])
extends MatrixFactorizationModel(rank, userFeatures, productFeatures)
with PersistentModel[ALSAlgorithmParams] {
// ADDED
def recommendProductsWithFilter(user: Int, num: Int, productIdFilter: Set[Int]) = {
val filteredProductFeatures = productFeatures
.filter { case (id, _) => !productIdFilter.contains(id) } // (*)
recommend(userFeatures.lookup(user).head, filteredProductFeatures, num)
.map(t => Rating(user, t._1, t._2))
}
// ADDED
private def recommend(
recommendToFeatures: Array[Double],
recommendableFeatures: RDD[(Int, Array[Double])],
num: Int): Array[(Int, Double)] = {
val scored = recommendableFeatures.map { case (id, features) =>
(id, blas.ddot(features.length, recommendToFeatures, 1, features, 1))
}
scored.top(num)(Ordering.by(_._2))
}
def save(id: String, params: ALSAlgorithmParams,
sc: SparkContext): Boolean = {
sc.parallelize(Seq(rank)).saveAsObjectFile(s"/tmp/${id}/rank")
userFeatures.saveAsObjectFile(s"/tmp/${id}/userFeatures")
productFeatures.saveAsObjectFile(s"/tmp/${id}/productFeatures")
sc.parallelize(Seq(userStringIntMap))
.saveAsObjectFile(s"/tmp/${id}/userStringIntMap")
sc.parallelize(Seq(itemStringIntMap))
.saveAsObjectFile(s"/tmp/${id}/itemStringIntMap")
true
}
override def toString = {
s"userFeatures: [${userFeatures.count()}]" +
s"(${userFeatures.take(2).toList}...)" +
s" productFeatures: [${productFeatures.count()}]" +
s"(${productFeatures.take(2).toList}...)" +
s" userStringIntMap: [${userStringIntMap.size}]" +
s"(${userStringIntMap.take(2)}...)" +
s" itemStringIntMap: [${itemStringIntMap.size}]" +
s"(${itemStringIntMap.take(2)}...)"
}
}
object ALSModel
extends PersistentModelLoader[ALSAlgorithmParams, ALSModel] {
def apply(id: String, params: ALSAlgorithmParams,
sc: Option[SparkContext]) = {
new ALSModel(
rank = sc.get.objectFile[Int](s"/tmp/${id}/rank").first,
userFeatures = sc.get.objectFile(s"/tmp/${id}/userFeatures"),
productFeatures = sc.get.objectFile(s"/tmp/${id}/productFeatures"),
userStringIntMap = sc.get
.objectFile[BiMap[String, Int]](s"/tmp/${id}/userStringIntMap").first,
itemStringIntMap = sc.get
.objectFile[BiMap[String, Int]](s"/tmp/${id}/itemStringIntMap").first)
}
}