/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.predictionio.examples.similarproduct

import org.apache.predictionio.data.storage.BiMap

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}

import grizzled.slf4j.Logger

// ADDED
// Extend original ALSAlgorithm and override train() function to handle
// like and dislike events
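//
// For context, a minimal sketch of how this algorithm could be registered in
// the template's engine factory. The "likealgo" key and the
// SimilarProductEngine object name are assumptions for illustration; match
// them to your own Engine.scala and engine.json:
//
//   object SimilarProductEngine extends EngineFactory {
//     def apply() = {
//       new Engine(
//         classOf[DataSource],
//         classOf[Preparator],
//         Map(
//           "als" -> classOf[ALSAlgorithm],
//           "likealgo" -> classOf[LikeAlgorithm]),
//         classOf[Serving])
//     }
//   }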
class LikeAlgorithm(ap: ALSAlgorithmParams) extends ALSAlgorithm(ap) {

  @transient lazy override val logger = Logger[this.type]

  override
  def train(sc: SparkContext, data: PreparedData): ALSModel = {
    require(!data.likeEvents.take(1).isEmpty,
      s"likeEvents in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")
    require(!data.users.take(1).isEmpty,
      s"users in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")
    require(!data.items.take(1).isEmpty,
      s"items in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")

    // create user and item String ID to Int index BiMaps
    val userStringIntMap = BiMap.stringInt(data.users.keys)
    val itemStringIntMap = BiMap.stringInt(data.items.keys)
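
    // For illustration (assumed BiMap behavior): each String ID gets a
    // distinct Int index, e.g. itemStringIntMap("i1") might map to 1 and
    // itemStringIntMap.inverse(1) back to "i1", so IDs can be translated to
    // MLlib's Int indices and back.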

    // collect Item as Map and convert ID to Int index
    val items: Map[Int, Item] = data.items.map { case (id, item) =>
      (itemStringIntMap(id), item)
    }.collectAsMap.toMap
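    // Note: items are collected to the driver here because the resulting
    // ALSModel keeps this map in memory for serving-time lookups; this
    // assumes the item catalog comfortably fits on the driver.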

    val mllibRatings = data.likeEvents
      .map { r =>
        // Convert user and item String IDs to Int index for MLlib
        val uindex = userStringIntMap.getOrElse(r.user, -1)
        val iindex = itemStringIntMap.getOrElse(r.item, -1)

        if (uindex == -1)
          logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
            + " to Int index.")

        if (iindex == -1)
          logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
            + " to Int index.")

        // key is (uindex, iindex) tuple, value is (like, t) tuple
        ((uindex, iindex), (r.like, r.t))
      }.filter { case ((u, i), v) =>
        // keep events with valid user and item index
        (u != -1) && (i != -1)
      }.reduceByKey { case (v1, v2) => // MODIFIED
        // A user may like an item and later change to dislike it,
        // or vice versa. Keep only the latest event in that case.
        val (like1, t1) = v1
        val (like2, t2) = v2
        // keep the latest value
        if (t1 > t2) v1 else v2
      }.map { case ((u, i), (like, t)) => // MODIFIED
        // With ALS.trainImplicit(), we can use a negative value to indicate
        // a negative signal (i.e. dislike)
        val r = if (like) 1 else -1
        // MLlibRating requires integer index for user and item
        MLlibRating(u, i, r)
      }.cache()
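
    // Illustration with hypothetical events: if user index 2 liked item
    // index 5 at t=1000 and then disliked it at t=2000, the reduceByKey
    // above keeps (false, 2000), so the pair contributes MLlibRating(2, 5, -1).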

    // MLlib ALS cannot handle empty training data.
    require(!mllibRatings.take(1).isEmpty,
      s"mllibRatings cannot be empty." +
      " Please check if your events contain valid user and item ID.")

    // seed for MLlib ALS
    val seed = ap.seed.getOrElse(System.nanoTime)

    val m = ALS.trainImplicit(
      ratings = mllibRatings,
      rank = ap.rank,
      iterations = ap.numIterations,
      lambda = ap.lambda,
      blocks = -1,
      alpha = 1.0,
      seed = seed)
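
    // Note on semantics (based on MLlib's implicit-feedback ALS, following
    // Hu, Koren & Volinsky): a rating's magnitude acts as confidence and its
    // sign as preference, so the -1 "dislike" ratings are treated as
    // confidently observed non-preference rather than as unobserved pairs.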

    new ALSModel(
      productFeatures = m.productFeatures.collectAsMap.toMap,
      itemStringIntMap = itemStringIntMap,
      items = items
    )
  }

}