| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.predictionio.examples.classification |
| |
| import org.apache.predictionio.controller.PDataSource |
| import org.apache.predictionio.controller.EmptyEvaluationInfo |
| import org.apache.predictionio.controller.Params |
| import org.apache.predictionio.data.store.PEventStore |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.rdd.RDD |
| import org.apache.spark.mllib.regression.LabeledPoint |
| import org.apache.spark.mllib.linalg.Vectors |
| |
| import grizzled.slf4j.Logger |
| |
/** Configuration for [[DataSource]].
  *
  * @param appName name of the app whose events are read from the event store
  * @param evalK   number of folds used for k-fold evaluation; must be defined
  *                when `DataSource.readEval` is called
  */
case class DataSourceParams(appName: String, evalK: Option[Int]) extends Params
| |
/** Data source that turns aggregated "item" entity properties into
  * `LabeledPoint`s for training and k-fold evaluation.
  *
  * @param dsp data source parameters (app name and optional fold count)
  */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, ActualResult] {

  @transient lazy val logger: Logger = Logger[this.type]

  /** Reads every labeled point of the app and wraps it as training data.
    *
    * @param sc Spark context used to read from the event store
    * @return training data holding all labeled points
    */
  override
  def readTraining(sc: SparkContext): TrainingData =
    new TrainingData(readLabeledPoints(sc))

  /** Builds the k-fold evaluation sets.
    *
    * The point with index `i` is a test point of fold `i % evalK` and a
    * training point of every other fold.
    *
    * @param sc Spark context used to read from the event store
    * @return one (training data, evaluation info, (query, actual result) pairs)
    *         tuple per fold
    * @throws IllegalArgumentException if `dsp.evalK` is `None`
    */
  override
  def readEval(sc: SparkContext)
  : Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
    // Fail before touching the event store; the exception type and message
    // are the same as the `require` call this replaces.
    val evalK = dsp.evalK.getOrElse(
      throw new IllegalArgumentException(
        "requirement failed: DataSourceParams.evalK must not be None"))

    // K-fold splitting
    val indexedPoints: RDD[(LabeledPoint, Long)] =
      readLabeledPoints(sc).zipWithIndex()

    (0 until evalK).map { idx =>
      val trainingPoints = indexedPoints.filter(_._2 % evalK != idx).map(_._1)
      val testingPoints = indexedPoints.filter(_._2 % evalK == idx).map(_._1)

      (
        new TrainingData(trainingPoints),
        new EmptyEvaluationInfo(),
        testingPoints.map { p =>
          (Query(p.features(0), p.features(1), p.features(2), p.features(3)),
            ActualResult(p.label))
        }
      )
    }
  }

  /** Reads the aggregated properties of all "item" entities and converts each
    * one into a `LabeledPoint`. Shared by `readTraining` and `readEval` so the
    * two cannot drift apart.
    *
    * A conversion failure is logged together with the offending entity and
    * then rethrown.
    */
  private def readLabeledPoints(sc: SparkContext): RDD[LabeledPoint] =
    PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item",
      // only keep entities with these required properties defined
      required = Some(List(
        "featureA", "featureB", "featureC", "featureD", "label")))(sc)
      // aggregateProperties() returns RDD pair of
      // entity ID and its aggregated properties
      .map { case (entityId, properties) =>
        try {
          LabeledPoint(properties.get[Double]("label"),
            Vectors.dense(Array(
              properties.get[Double]("featureA"),
              properties.get[Double]("featureB"),
              properties.get[Double]("featureC"),
              properties.get[Double]("featureD")
            ))
          )
        } catch {
          case e: Exception => {
            logger.error(s"Failed to get properties ${properties} of" +
              s" ${entityId}. Exception: ${e}.")
            throw e
          }
        }
      }.cache()
}
| |
/** Serializable holder for the labeled points produced by the data source.
  *
  * @param labeledPoints the labeled examples
  */
class TrainingData(val labeledPoints: RDD[LabeledPoint]) extends Serializable