// blob: 15131f715da3c0be1c7a31ee5991149562394e6c [file] [log] [blame]
package pio.refactor
import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import grizzled.slf4j.Logger
/**
 * Parameters for [[Algorithm]].
 *
 * @param mult factor that [[Algorithm.train]] multiplies into the combined event value
 */
case class AlgorithmParams(mult: Int) extends Params
/**
 * Toy algorithm: training collapses the events into one integer, and
 * prediction combines that integer with the query.
 *
 * Extend `PAlgorithm` instead of `P2LAlgorithm` if `Model` contains an RDD.
 *
 * @param ap algorithm parameters; `ap.mult` scales the summed events in `train`
 */
class Algorithm(val ap: AlgorithmParams)
  extends P2LAlgorithm[TrainingData, Model, Query, PredictedResult] {

  // Transient and lazy: the logger is excluded from Java serialization of
  // this instance and is created on first use.
  @transient lazy val logger = Logger[this.type]

  /**
   * Sums all event values, multiplies the sum by `ap.mult`, and stores the
   * product as the model.
   *
   * An empty event set yields a sum of 0 (and therefore `mc = 0`) instead of
   * the exception that `reduce` raises on empty input.
   *
   * @param data training data; assumes `data.events` is an RDD or collection
   *             of `Int` (declared outside this file section - TODO confirm)
   * @return a [[Model]] whose `mc` is the scaled sum
   */
  def train(data: TrainingData): Model = {
    // `fold` with the additive identity equals `reduce(_ + _)` for non-empty
    // input and is also defined for empty input.
    val count = data.events.fold(0)(_ + _) * ap.mult
    // NOTE(review): logged at error level; presumably so the trace shows up
    // regardless of log configuration - confirm before lowering the level.
    logger.error("Algo.train")
    new Model(mc = count)
  }

  /**
   * Combines the model value with the query via `model.mc + query.q`.
   *
   * An earlier variant built the string `s"${model.mc}-${query.q}"`; it is
   * no longer used.
   *
   * @param model model produced by `train`
   * @param query incoming query; its `q` field is combined with `model.mc`
   * @return a `PredictedResult` whose `p` is `model.mc + query.q`
   */
  def predict(model: Model, query: Query): PredictedResult = {
    logger.error("Algo.predict")
    PredictedResult(p = model.mc + query.q)
  }
}
/**
 * Trained model: a single integer value.
 *
 * @param mc the integer stored as the model
 */
class Model(val mc: Int) extends Serializable {

  /** Renders the model as `mc=<value>`. */
  override def toString: String = s"mc=$mc"
}