blob: ba7cf28c2c795ecb428c240a7b1bfcc3102cadf9 [file] [log] [blame]
package org.apache.spark.mllib.recommendation
// This must be the same package as Spark's MatrixFactorizationModel because
// MatrixFactorizationModel's constructor is private and we are using
// its constructor in order to save and load the model
import org.template.recommendation.ALSAlgorithmParams
import org.template.recommendation.User
import org.template.recommendation.Item
import io.prediction.controller.IPersistentModel
import io.prediction.controller.IPersistentModelLoader
import io.prediction.data.storage.EntityMap
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
class ALSModel(
override val rank: Int,
override val userFeatures: RDD[(Int, Array[Double])],
override val productFeatures: RDD[(Int, Array[Double])],
val users: EntityMap[User],
val items: EntityMap[Item])
extends MatrixFactorizationModel(rank, userFeatures, productFeatures)
with IPersistentModel[ALSAlgorithmParams] {
def save(id: String, params: ALSAlgorithmParams,
sc: SparkContext): Boolean = {
sc.parallelize(Seq(rank)).saveAsObjectFile(s"/tmp/${id}/rank")
userFeatures.saveAsObjectFile(s"/tmp/${id}/userFeatures")
productFeatures.saveAsObjectFile(s"/tmp/${id}/productFeatures")
sc.parallelize(Seq(users))
.saveAsObjectFile(s"/tmp/${id}/users")
sc.parallelize(Seq(items))
.saveAsObjectFile(s"/tmp/${id}/items")
true
}
override def toString = {
s"userFeatures: [${userFeatures.count()}]" +
s"(${userFeatures.take(2).toList}...)" +
s" productFeatures: [${productFeatures.count()}]" +
s"(${productFeatures.take(2).toList}...)" +
s" users: [${users.size}]" +
s"(${users.take(2)}...)" +
s" items: [${items.size}]" +
s"(${items.take(2)}...)"
}
}
object ALSModel
extends IPersistentModelLoader[ALSAlgorithmParams, ALSModel] {
def apply(id: String, params: ALSAlgorithmParams,
sc: Option[SparkContext]) = {
new ALSModel(
rank = sc.get.objectFile[Int](s"/tmp/${id}/rank").first,
userFeatures = sc.get.objectFile(s"/tmp/${id}/userFeatures"),
productFeatures = sc.get.objectFile(s"/tmp/${id}/productFeatures"),
users = sc.get
.objectFile[EntityMap[User]](s"/tmp/${id}/users").first,
items = sc.get
.objectFile[EntityMap[Item]](s"/tmp/${id}/items").first)
}
}