/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.predictionio.examples.classification

import org.apache.predictionio.controller.PDataSource
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.store.PEventStore

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

import grizzled.slf4j.Logger
case class DataSourceParams(
  appName: String,
  evalK: Option[Int] // k for k-fold cross-validation; required by readEval
) extends Params
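
// A hypothetical instantiation (values are illustrative only):
//   DataSourceParams(appName = "MyClassificationApp", evalK = Some(5))
// In a deployed engine these parameters come from the "datasource" section
// of engine.json.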

class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, ActualResult] {

  @transient lazy val logger = Logger[this.type]

  override
  def readTraining(sc: SparkContext): TrainingData = {
    val labeledPoints: RDD[LabeledPoint] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item", // MODIFIED
      // only keep entities with all of these required properties defined
      required = Some(List( // MODIFIED
        "featureA", "featureB", "featureC", "featureD", "label")))(sc)
      // aggregateProperties() returns an RDD of
      // (entity ID, aggregated properties) pairs
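      // A hypothetical "$set" event that would yield such an entity
      // (values are illustrative only):
      //   { "event": "$set", "entityType": "item", "entityId": "i0",
      //     "properties": { "featureA": 2.0, "featureB": 40.0,
      //                     "featureC": 1.0, "featureD": 3.0, "label": 0.0 } }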
      .map { case (entityId, properties) =>
        try {
          // MODIFIED
          LabeledPoint(properties.get[Double]("label"),
            Vectors.dense(Array(
              properties.get[Double]("featureA"),
              properties.get[Double]("featureB"),
              properties.get[Double]("featureC"),
              properties.get[Double]("featureD")
            ))
          )
        } catch {
          case e: Exception =>
            logger.error(s"Failed to get properties $properties of" +
              s" $entityId. Exception: $e.")
            throw e
        }
      }.cache()

    new TrainingData(labeledPoints)
  }

  override
  def readEval(sc: SparkContext)
    : Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
    require(dsp.evalK.nonEmpty, "DataSourceParams.evalK must not be None")

    // The following code reads the data from the data store. It is equivalent
    // to the readTraining method; the exact code is copied and pasted here for
    // illustration purposes. A recommended approach is to factor this logic
    // out into a helper function and have both readTraining and readEval call
    // it (see the sketch after this method).
    val labeledPoints: RDD[LabeledPoint] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item", // MODIFIED
      // only keep entities with all of these required properties defined
      required = Some(List( // MODIFIED
        "featureA", "featureB", "featureC", "featureD", "label")))(sc)
      // aggregateProperties() returns an RDD of
      // (entity ID, aggregated properties) pairs
      .map { case (entityId, properties) =>
        try {
          // MODIFIED
          LabeledPoint(properties.get[Double]("label"),
            Vectors.dense(Array(
              properties.get[Double]("featureA"),
              properties.get[Double]("featureB"),
              properties.get[Double]("featureC"),
              properties.get[Double]("featureD")
            ))
          )
        } catch {
          case e: Exception =>
            logger.error(s"Failed to get properties $properties of" +
              s" $entityId. Exception: $e.")
            throw e
        }
      }.cache()
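    // Caching matters here: labeledPoints is filtered twice per fold below,
    // and cache() avoids re-aggregating properties from the event store on
    // each pass.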
    // End of reading from data store

    // K-fold splitting
    val evalK = dsp.evalK.get
    val indexedPoints: RDD[(LabeledPoint, Long)] = labeledPoints.zipWithIndex()
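
    // Each point's index modulo evalK selects its test fold: with evalK = 3,
    // for example, points 0, 3, 6, ... form the test set of fold 0 and serve
    // as training data in folds 1 and 2.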
    (0 until evalK).map { idx =>
      val trainingPoints = indexedPoints.filter(_._2 % evalK != idx).map(_._1)
      val testingPoints = indexedPoints.filter(_._2 % evalK == idx).map(_._1)

      (
        new TrainingData(trainingPoints),
        new EmptyEvaluationInfo(),
        testingPoints.map {
          // MODIFIED
          p => (Query(p.features(0), p.features(1), p.features(2),
            p.features(3)), ActualResult(p.label))
        }
      )
    }
  }
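
  // A minimal sketch of the helper suggested in readEval's comment above. The
  // name readLabeledPoints is illustrative and not part of the template; both
  // readTraining and readEval could delegate to it (per-record error handling
  // omitted for brevity).
  private def readLabeledPoints(sc: SparkContext): RDD[LabeledPoint] = {
    PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item",
      required = Some(List(
        "featureA", "featureB", "featureC", "featureD", "label")))(sc)
      .map { case (_, properties) =>
        LabeledPoint(properties.get[Double]("label"),
          Vectors.dense(Array(
            properties.get[Double]("featureA"),
            properties.get[Double]("featureB"),
            properties.get[Double]("featureC"),
            properties.get[Double]("featureD"))))
      }.cache()
  }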
}

class TrainingData(
  val labeledPoints: RDD[LabeledPoint]
) extends Serializable
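
// Query and ActualResult are defined elsewhere in this package (Engine.scala
// in the template). For this modified four-feature example they are assumed
// to look roughly like:
//   case class Query(featureA: Double, featureB: Double,
//                    featureC: Double, featureD: Double)
//   case class ActualResult(label: Double)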