| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.predictionio.examples.similarproduct |
| |
| import org.apache.predictionio.controller.PDataSource |
| import org.apache.predictionio.controller.EmptyEvaluationInfo |
| import org.apache.predictionio.controller.EmptyActualResult |
| import org.apache.predictionio.controller.Params |
| import org.apache.predictionio.data.storage.Event |
| import org.apache.predictionio.data.store.PEventStore |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.SparkContext._ |
| import org.apache.spark.rdd.RDD |
| |
| import grizzled.slf4j.Logger |
| |
/** Configuration for [[DataSource]].
  *
  * @param appName forwarded as the `appName` argument of every event-store
  *                read performed by [[DataSource]]
  */
case class DataSourceParams(
  appName: String
) extends Params
| |
/** Data source that assembles [[TrainingData]] from the event store.
  *
  * Three datasets are read for the app named by `dsp.appName`: aggregated
  * "user" properties, aggregated "item" properties, and "rate" events sent
  * from a user to an item.
  *
  * @param dsp data source parameters (supplies the app name)
  */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData, EmptyEvaluationInfo, Query, EmptyActualResult] {

  @transient lazy val logger = Logger[this.type]

  /** Reads users, items and rate events (in that order) and bundles them.
    *
    * @param sc Spark context handed to the event-store readers
    * @return the training data holding the three cached RDDs
    */
  override def readTraining(sc: SparkContext): TrainingData = {
    val userData = loadUsers(sc)
    val itemData = loadItems(sc)
    val ratings = loadRateEvents(sc) // MODIFIED

    new TrainingData(
      users = userData,
      items = itemData,
      rateEvents = ratings // MODIFIED
    )
  }

  /** Builds the cached RDD of (entityID, User) from aggregated "user" properties. */
  private def loadUsers(sc: SparkContext): RDD[(String, User)] = {
    val aggregated = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user"
    )(sc)

    aggregated.map { case (userId, props) =>
      val built =
        try User()
        catch {
          // Log the offending entity, then rethrow so the read still fails.
          case err: Exception =>
            logger.error(s"Failed to get properties ${props} of" +
              s" user ${userId}. Exception: ${err}.")
            throw err
        }
      (userId, built)
    }.cache()
  }

  /** Builds the cached RDD of (entityID, Item) from aggregated "item" properties. */
  private def loadItems(sc: SparkContext): RDD[(String, Item)] = {
    val aggregated = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item"
    )(sc)

    aggregated.map { case (itemId, props) =>
      val built =
        try {
          // "categories" is treated as an optional item property.
          Item(categories = props.getOpt[List[String]]("categories"))
        } catch {
          // Log the offending entity, then rethrow so the read still fails.
          case err: Exception =>
            logger.error(s"Failed to get properties ${props} of" +
              s" item ${itemId}. Exception: ${err}.")
            throw err
        }
      (itemId, built)
    }.cache()
  }

  /** Finds all "user" "rate" "item" events and converts them to [[RateEvent]]s. */
  private def loadRateEvents(sc: SparkContext): RDD[RateEvent] = { // MODIFIED
    val rawEvents: RDD[Event] = PEventStore.find(
      appName = dsp.appName,
      entityType = Some("user"),
      eventNames = Some(List("rate")), // MODIFIED
      // targetEntityType is an optional field of an event.
      targetEntityType = Some(Some("item")))(sc)

    rawEvents.map(toRateEvent).cache()
  }

  /** Converts one stored event into a [[RateEvent]].
    *
    * Any exception raised during conversion (unexpected event name, absent
    * target entity id, failure reading "rating") is logged and rethrown.
    */
  private def toRateEvent(event: Event): RateEvent = // MODIFIED
    try {
      if (event.event == "rate") {
        RateEvent( // MODIFIED
          user = event.entityId,
          item = event.targetEntityId.get,
          rating = event.properties.get[Double]("rating"), // ADDED
          t = event.eventTime.getMillis)
      } else {
        throw new Exception(s"Unexpected event ${event} is read.")
      }
    } catch {
      case err: Exception =>
        logger.error(s"Cannot convert ${event} to RateEvent." + // MODIFIED
          s" Exception: ${err}.")
        throw err
    }
}
| |
/** A user entity; it carries no properties of its own. */
case class User(
)
| |
/** An item entity.
  *
  * @param categories the item's categories, or `None` when the property is absent
  */
case class Item(
  categories: Option[List[String]]
)
| |
// MODIFIED
/** A single "rate" event.
  *
  * @param user   id of the entity that rated
  * @param item   id of the target entity that was rated
  * @param rating value of the event's "rating" property
  * @param t      event time in milliseconds
  */
case class RateEvent(
  user: String,
  item: String,
  rating: Double,
  t: Long
)
| |
/** Container for the datasets produced by the data source.
  *
  * @param users      (entityID, User) pairs
  * @param items      (entityID, Item) pairs
  * @param rateEvents rate events from users to items
  */
class TrainingData(
  val users: RDD[(String, User)],
  val items: RDD[(String, Item)],
  val rateEvents: RDD[RateEvent] // MODIFIED
) extends Serializable {

  /** Summarises each dataset as `name: [count (first two elements...)]`.
    *
    * The three segments are separated by a single space and share one bracket
    * layout. Note that this calls `count()` and `take(2)` on every RDD, so
    * rendering the string evaluates them.
    */
  override def toString: String = {
    s"users: [${users.count()} (${users.take(2).toList}...)] " +
    s"items: [${items.count()} (${items.take(2).toList}...)] " +
    // MODIFIED
    s"rateEvents: [${rateEvents.count()} (${rateEvents.take(2).toList}...)]"
  }
}