blob: a78acb58fdc20dd2d30d7754d74c3f1b327a0d5e [file] [log] [blame]
/** Copyright 2015 TappingStone, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prediction.controller
import io.prediction.core.BaseAlgorithm
import io.prediction.workflow.PersistentModelManifest
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import scala.reflect._
/** Base class of a parallel algorithm.
*
* A parallel algorithm can be run in parallel on a cluster and produces a
* model that can also be distributed across a cluster.
*
* If your input query class requires custom JSON4S serialization, the most
* idiomatic way is to implement a trait that extends [[CustomQuerySerializer]],
* and mix that into your algorithm class, instead of overriding
* [[querySerializer]] directly.
*
* To provide evaluation feature, one must override and implement the
* [[batchPredict]] method. Otherwise, an exception will be thrown when `pio eval`
* is used.
*
* @tparam PD Prepared data class.
* @tparam M Trained model class.
* @tparam Q Input query class.
* @tparam P Output prediction class.
* @group Algorithm
*/
abstract class PAlgorithm[PD, M, Q : Manifest, P]
  extends BaseAlgorithm[PD, M, Q, P] {

  /** Do not use directly or override this method, as this is called by
    * PredictionIO workflow to train a model.
    */
  private[prediction]
  def trainBase(sc: SparkContext, pd: PD): M = train(sc, pd)

  /** Implement this method to produce a model from prepared data.
    *
    * @param sc A Spark context.
    * @param pd Prepared data for model training.
    * @return Trained model.
    */
  def train(sc: SparkContext, pd: PD): M

  /** Do not use directly or override this method, as this is called by
    * PredictionIO workflow to perform batch prediction (e.g. during
    * evaluation). Downcasts the untyped model to `M` and delegates to
    * [[batchPredict]].
    */
  private[prediction]
  def batchPredictBase(sc: SparkContext, bm: Any, qs: RDD[(Long, Q)])
    : RDD[(Long, P)] = batchPredict(bm.asInstanceOf[M], qs)

  /** To provide evaluation feature, one must override and implement this
    * method to generate many predictions in batch. Otherwise, an exception
    * will be thrown when `pio eval` is used.
    *
    * The default implementation throws a [[NotImplementedError]].
    *
    * @param m Trained model produced by [[train]].
    * @param qs An RDD of index-query tuples. The index is used to keep track
    *           of predicted results with corresponding queries.
    * @return An RDD of index-prediction tuples whose indices correspond to
    *         those of `qs`.
    */
  def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] =
    throw new NotImplementedError("batchPredict not implemented")

  /** Do not use directly or override this method, as this is called by
    * PredictionIO workflow to perform prediction. Downcasts the untyped model
    * to `M` and delegates to [[predict]].
    */
  private[prediction]
  def predictBase(baseModel: Any, query: Q): P = {
    predict(baseModel.asInstanceOf[M], query)
  }

  /** Implement this method to produce a prediction from a query and trained
    * model.
    *
    * @param model Trained model produced by [[train]].
    * @param query An input query.
    * @return A prediction.
    */
  def predict(model: M, query: Q): P

  /** Do not use directly or override this method, as this is called by
    * PredictionIO workflow to persist a trained model.
    *
    * In general, a parallel model may contain multiple RDDs. It is not easy
    * to infer and persist them programmatically since these RDDs may be
    * potentially huge, so persistence is delegated to the model itself.
    *
    * @param sc A Spark context.
    * @param modelId Identifier under which the model should be saved.
    * @param algoParams Algorithm parameters forwarded to the model's `save`.
    * @param bm Untyped trained model; cast to `M` internally.
    * @return A [[PersistentModelManifest]] when the model is a
    *         [[PersistentModel]] and its `save` returns true; the `Unit`
    *         sentinel otherwise.
    */
  private[prediction]
  override
  def makePersistentModel(
    sc: SparkContext,
    modelId: String,
    algoParams: Params,
    bm: Any): Any = {
    // Persistence succeeds only when the model implements PersistentModel AND
    // its save() reports success; every other outcome falls through to the
    // sentinel. A typed pattern with a guard replaces the original
    // isInstanceOf/asInstanceOf chain and its duplicated else-branches.
    bm.asInstanceOf[M] match {
      case model: PersistentModel[_]
          if model.asInstanceOf[PersistentModel[Params]]
            .save(modelId, algoParams, sc) =>
        PersistentModelManifest(className = model.getClass.getName)
      case _ =>
        // NOTE(review): this returns the `Unit` companion object (not `()`),
        // kept as-is because the workflow inspects the returned value's type.
        Unit
    }
  }
}