| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.mllib.recommendation |
| // This must be the same package as Spark's MatrixFactorizationModel because |
| // MatrixFactorizationModel's constructor is private and we are using |
| // its constructor in order to save and load the model |
| |
| import com.github.fommil.netlib.BLAS.{getInstance => blas} // ADDED |
| import org.apache.predictionio.examples.recommendation.ALSAlgorithmParams |
| |
| import org.apache.predictionio.controller.PersistentModel |
| import org.apache.predictionio.controller.PersistentModelLoader |
| import org.apache.predictionio.data.storage.BiMap |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.SparkContext._ |
| import org.apache.spark.rdd.RDD |
| |
| class ALSModel( |
| override val rank: Int, |
| override val userFeatures: RDD[(Int, Array[Double])], |
| override val productFeatures: RDD[(Int, Array[Double])], |
| val userStringIntMap: BiMap[String, Int], |
| val itemStringIntMap: BiMap[String, Int]) |
| extends MatrixFactorizationModel(rank, userFeatures, productFeatures) |
| with PersistentModel[ALSAlgorithmParams] { |
| |
| // ADDED |
| def recommendProductsWithFilter(user: Int, num: Int, productIdFilter: Set[Int]) = { |
| val filteredProductFeatures = productFeatures |
| .filter { case (id, _) => !productIdFilter.contains(id) } // (*) |
| recommend(userFeatures.lookup(user).head, filteredProductFeatures, num) |
| .map(t => Rating(user, t._1, t._2)) |
| } |
| |
| // ADDED |
| private def recommend( |
| recommendToFeatures: Array[Double], |
| recommendableFeatures: RDD[(Int, Array[Double])], |
| num: Int): Array[(Int, Double)] = { |
| val scored = recommendableFeatures.map { case (id, features) => |
| (id, blas.ddot(features.length, recommendToFeatures, 1, features, 1)) |
| } |
| scored.top(num)(Ordering.by(_._2)) |
| } |
| |
| def save(id: String, params: ALSAlgorithmParams, |
| sc: SparkContext): Boolean = { |
| |
| sc.parallelize(Seq(rank)).saveAsObjectFile(s"/tmp/${id}/rank") |
| userFeatures.saveAsObjectFile(s"/tmp/${id}/userFeatures") |
| productFeatures.saveAsObjectFile(s"/tmp/${id}/productFeatures") |
| sc.parallelize(Seq(userStringIntMap)) |
| .saveAsObjectFile(s"/tmp/${id}/userStringIntMap") |
| sc.parallelize(Seq(itemStringIntMap)) |
| .saveAsObjectFile(s"/tmp/${id}/itemStringIntMap") |
| true |
| } |
| |
| override def toString = { |
| s"userFeatures: [${userFeatures.count()}]" + |
| s"(${userFeatures.take(2).toList}...)" + |
| s" productFeatures: [${productFeatures.count()}]" + |
| s"(${productFeatures.take(2).toList}...)" + |
| s" userStringIntMap: [${userStringIntMap.size}]" + |
| s"(${userStringIntMap.take(2)}...)" + |
| s" itemStringIntMap: [${itemStringIntMap.size}]" + |
| s"(${itemStringIntMap.take(2)}...)" |
| } |
| } |
| |
| object ALSModel |
| extends PersistentModelLoader[ALSAlgorithmParams, ALSModel] { |
| def apply(id: String, params: ALSAlgorithmParams, |
| sc: Option[SparkContext]) = { |
| new ALSModel( |
| rank = sc.get.objectFile[Int](s"/tmp/${id}/rank").first, |
| userFeatures = sc.get.objectFile(s"/tmp/${id}/userFeatures"), |
| productFeatures = sc.get.objectFile(s"/tmp/${id}/productFeatures"), |
| userStringIntMap = sc.get |
| .objectFile[BiMap[String, Int]](s"/tmp/${id}/userStringIntMap").first, |
| itemStringIntMap = sc.get |
| .objectFile[BiMap[String, Int]](s"/tmp/${id}/itemStringIntMap").first) |
| } |
| } |