/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.recommendation
// This class must live in the same package as Spark's MatrixFactorizationModel
// because MatrixFactorizationModel's constructor is private, and we need to
// invoke that constructor in order to save and load the model.
import org.apache.predictionio.examples.recommendation.ALSAlgorithmParams
import org.apache.predictionio.controller.PersistentModel
import org.apache.predictionio.controller.PersistentModelLoader
import org.apache.predictionio.data.storage.BiMap
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
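
/**
 * A persistable wrapper around MLlib's MatrixFactorizationModel. Besides the
 * factorization itself (rank plus user and product feature vectors), it
 * carries the BiMaps that translate the engine's String user and item IDs to
 * the Int IDs that MLlib requires.
 */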
class ALSModel(
    override val rank: Int,
    override val userFeatures: RDD[(Int, Array[Double])],
    override val productFeatures: RDD[(Int, Array[Double])],
    val userStringIntMap: BiMap[String, Int],
    val itemStringIntMap: BiMap[String, Int])
  extends MatrixFactorizationModel(rank, userFeatures, productFeatures)
  with PersistentModel[ALSAlgorithmParams] {
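
  // Persist each component of the model as a Hadoop object file under
  // /tmp/<id>/. /tmp keeps the example simple on a single machine; a
  // cluster deployment would point these paths at a shared filesystem
  // such as HDFS instead.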
  def save(id: String, params: ALSAlgorithmParams,
    sc: SparkContext): Boolean = {

    sc.parallelize(Seq(rank)).saveAsObjectFile(s"/tmp/${id}/rank")
    userFeatures.saveAsObjectFile(s"/tmp/${id}/userFeatures")
    productFeatures.saveAsObjectFile(s"/tmp/${id}/productFeatures")
    sc.parallelize(Seq(userStringIntMap))
      .saveAsObjectFile(s"/tmp/${id}/userStringIntMap")
    sc.parallelize(Seq(itemStringIntMap))
      .saveAsObjectFile(s"/tmp/${id}/itemStringIntMap")
    true
  }
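
  // Human-readable summary: the size of each component plus its first two
  // entries, handy for sanity-checking the model in logs.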
  override def toString = {
    s"userFeatures: [${userFeatures.count()}]" +
      s"(${userFeatures.take(2).toList}...)" +
      s" productFeatures: [${productFeatures.count()}]" +
      s"(${productFeatures.take(2).toList}...)" +
      s" userStringIntMap: [${userStringIntMap.size}]" +
      s"(${userStringIntMap.take(2)}...)" +
      s" itemStringIntMap: [${itemStringIntMap.size}]" +
      s"(${itemStringIntMap.take(2)}...)"
  }
}
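
/**
 * Loader counterpart to save(): reads each persisted component back from
 * /tmp/<id>/ and reassembles the model. A live SparkContext is required
 * here, hence the sc.get calls on the Option.
 */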
object ALSModel
  extends PersistentModelLoader[ALSAlgorithmParams, ALSModel] {

  def apply(id: String, params: ALSAlgorithmParams,
    sc: Option[SparkContext]) = {
    new ALSModel(
      rank = sc.get.objectFile[Int](s"/tmp/${id}/rank").first,
      userFeatures = sc.get.objectFile(s"/tmp/${id}/userFeatures"),
      productFeatures = sc.get.objectFile(s"/tmp/${id}/productFeatures"),
      userStringIntMap = sc.get
        .objectFile[BiMap[String, Int]](s"/tmp/${id}/userStringIntMap").first,
      itemStringIntMap = sc.get
        .objectFile[BiMap[String, Int]](s"/tmp/${id}/itemStringIntMap").first)
  }
}
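
// A minimal round-trip sketch, assuming a live SparkContext `sc`, an
// ALSAlgorithmParams instance `params`, and a hypothetical model id
// "als-model-v1":
//
//   model.save("als-model-v1", params, sc)
//   val restored = ALSModel("als-model-v1", params, Some(sc))
//
// In practice PredictionIO calls save() after `pio train` and apply() at
// deploy time, because the model extends PersistentModel.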