// blob: 44834b3ad3cbdf298cc022d6383009c2e7028393 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.examples.similarproduct
import org.apache.predictionio.controller.PDataSource
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.EmptyActualResult
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.storage.Event
import org.apache.predictionio.data.store.PEventStore
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import grizzled.slf4j.Logger
/** Parameters for [[DataSource]].
  *
  * @param appName name of the app whose events are read from the event store
  */
case class DataSourceParams(
  appName: String
) extends Params
/** Reads users, items and user-to-item "view" / "like" / "dislike" events from
  * the event store of the app named by `dsp.appName` and bundles them into a
  * [[TrainingData]].
  *
  * @param dsp data source parameters; only `appName` is read here
  */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, EmptyActualResult] {

  @transient lazy val logger: Logger = Logger[this.type]

  /** Builds the training data from the four RDDs below.
    *
    * @param sc Spark context used for every event-store query
    * @return training data holding the users, items, view events and like events
    */
  override
  def readTraining(sc: SparkContext): TrainingData =
    new TrainingData(
      users = readUsers(sc),
      items = readItems(sc),
      viewEvents = readViewEvents(sc),
      likeEvents = readLikeEvents(sc) // ADDED
    )

  /** Creates a cached RDD of (entityId, User) from the aggregated "user" properties. */
  private def readUsers(sc: SparkContext): RDD[(String, User)] =
    PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user"
    )(sc).map { case (entityId, properties) =>
      // User currently has no fields; the try/catch mirrors the item case so
      // that property extraction added here later is logged the same way.
      val user = try {
        User()
      } catch {
        case e: Exception =>
          logger.error(s"Failed to get properties ${properties} of" +
            s" user ${entityId}. Exception: ${e}.")
          throw e
      }
      (entityId, user)
    }.cache()

  /** Creates a cached RDD of (entityId, Item) from the aggregated "item" properties. */
  private def readItems(sc: SparkContext): RDD[(String, Item)] =
    PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item"
    )(sc).map { case (entityId, properties) =>
      val item = try {
        // Assume categories is optional property of item.
        Item(categories = properties.getOpt[List[String]]("categories"))
      } catch {
        case e: Exception =>
          logger.error(s"Failed to get properties ${properties} of" +
            s" item ${entityId}. Exception: ${e}.")
          throw e
      }
      (entityId, item)
    }.cache()

  /** Creates a cached RDD of all "user" "view" "item" events. */
  private def readViewEvents(sc: SparkContext): RDD[ViewEvent] =
    PEventStore.find(
      appName = dsp.appName,
      entityType = Some("user"),
      eventNames = Some(List("view")),
      // targetEntityType is optional field of an event.
      targetEntityType = Some(Some("item")))(sc)
      // PEventStore.find() returns RDD[Event]
      .map { event =>
        try {
          event.event match {
            case "view" => ViewEvent(
              user = event.entityId,
              item = targetItemOf(event),
              t = event.eventTime.getMillis)
            case _ => throw new Exception(s"Unexpected event ${event} is read.")
          }
        } catch {
          case e: Exception =>
            logger.error(s"Cannot convert ${event} to ViewEvent." +
              s" Exception: ${e}.")
            throw e
        }
      }.cache()

  // ADDED
  /** Creates a cached RDD of all "user" "like" and "dislike" "item" events. */
  private def readLikeEvents(sc: SparkContext): RDD[LikeEvent] =
    PEventStore.find(
      appName = dsp.appName,
      entityType = Some("user"),
      eventNames = Some(List("like", "dislike")),
      // targetEntityType is optional field of an event.
      targetEntityType = Some(Some("item")))(sc)
      // PEventStore.find() returns RDD[Event]
      .map { event =>
        try {
          event.event match {
            case "like" | "dislike" => LikeEvent(
              user = event.entityId,
              item = targetItemOf(event),
              t = event.eventTime.getMillis,
              like = (event.event == "like"))
            case _ => throw new Exception(s"Unexpected event ${event} is read.")
          }
        } catch {
          case e: Exception =>
            logger.error(s"Cannot convert ${event} to LikeEvent." +
              s" Exception: ${e}.")
            throw e
        }
      }.cache()

  /** Returns the target entity id of `event`.
    *
    * Throws the same exception type as `Option.get` (NoSuchElementException)
    * when the id is absent, but with a message that names the offending event
    * instead of the bare "None.get".
    */
  private def targetItemOf(event: Event): String =
    event.targetEntityId.getOrElse(
      throw new NoSuchElementException(
        s"Event ${event.event} of ${event.entityType} ${event.entityId}" +
          " has no targetEntityId."))
}
/** A user entity. It carries no properties; the id is kept as the RDD key. */
case class User(
)
/** An item entity.
  *
  * @param categories optional list of category names; None when not supplied
  */
case class Item(
  categories: Option[List[String]]
)
/** A single "view" interaction.
  *
  * @param user id of the viewing user
  * @param item id of the viewed item
  * @param t    event time in milliseconds
  */
case class ViewEvent(
  user: String,
  item: String,
  t: Long
)
// ADDED
/** A single "like" or "dislike" interaction.
  *
  * @param user id of the user who rated the item
  * @param item id of the rated item
  * @param t    event time in milliseconds
  * @param like true: like. false: dislike
  */
case class LikeEvent(user: String, item: String, t: Long, like: Boolean)
/** Bundle of the RDDs read by the data source.
  *
  * @param users      (entityId, User) pairs
  * @param items      (entityId, Item) pairs
  * @param viewEvents user-to-item view events
  * @param likeEvents user-to-item like / dislike events
  */
class TrainingData(
  val users: RDD[(String, User)],
  val items: RDD[(String, Item)],
  val viewEvents: RDD[ViewEvent],
  val likeEvents: RDD[LikeEvent] // ADDED
) extends Serializable {

  /** Summarises each RDD as its element count plus a sample of up to two elements.
    *
    * Note: this calls `count()` and `take(2)` on every RDD each time it is invoked.
    * The four sections are joined with a single space so they no longer run
    * together in the output.
    */
  override def toString: String =
    Seq(
      s"users: [${users.count()} (${users.take(2).toList}...)]",
      s"items: [${items.count()} (${items.take(2).toList}...)]",
      s"viewEvents: [${viewEvents.count()}] (${viewEvents.take(2).toList}...)",
      // ADDED
      s"likeEvents: [${likeEvents.count()}] (${likeEvents.take(2).toList}...)"
    ).mkString(" ")
}