blob: afe25ce5a0b0e2ebd0e9216fe6bf2c590adc364b [file] [log] [blame]
/** Copyright 2014 TappingStone, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prediction.controller.java
import io.prediction.core.BaseDataSource
import io.prediction.controller.Params
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import java.util.{ List => JList }
import java.util.{ ArrayList => JArrayList }
import java.util.{ Collections => JCollections }
import java.lang.{ Iterable => JIterable }
import scala.collection.JavaConversions._
import scala.reflect._
/**
 * Base class of a local data source.
 *
 * A local data source runs locally within a single machine and returns data
 * that can fit within a single machine.
 *
 * Subclasses either implement [[read]] to supply one or more sets of
 * training/test data, or implement [[readTraining]] and [[readTest]] to
 * supply exactly one set through the default [[read]].
 *
 * @tparam DSP Data Source Parameters
 * @tparam DP Data Parameters
 * @tparam TD Training Data
 * @tparam Q Input Query
 * @tparam A Actual Value
 */
abstract class LJavaDataSource[DSP <: Params, DP, TD, Q, A]
  extends BaseDataSource[DSP, DP, RDD[TD], Q, A]()(
    JavaUtils.fakeClassTag[DSP]) {

  /**
   * Turns every data set returned by [[read]] into Spark RDDs.
   *
   * [[read]] is invoked inside the `flatMap` of a single-element RDD, so it
   * runs as part of a Spark job rather than being called directly on the
   * driver. The combined result is cached and then split back into one entry
   * per data set.
   *
   * @param sc Spark context used to build the RDDs
   * @return one `(data params, training RDD, query/actual RDD)` tuple per
   *         element of [[read]], in the same order that [[read]] yields them
   */
  def readBase(sc: SparkContext): Seq[(DP, RDD[TD], RDD[(Q, A)])] = {
    implicit val fakeTdTag: ClassTag[TD] = JavaUtils.fakeClassTag[TD]

    // Each element is ((dp, td, qa), position of that set within read()).
    val datasets = sc.parallelize(Array(None))
      .flatMap(_ => read().toSeq)
      .zipWithIndex
    // Every RDD returned below is derived from `datasets`; cache it so that
    // read() is not re-evaluated for each of them.
    datasets.cache()

    // zipWithIndex numbers elements in order and collect() returns them in
    // partition order, so this array follows read()'s order. It is kept as an
    // ordered sequence on purpose: an immutable Map with more than four
    // entries does not preserve insertion order.
    val indexedDps: Array[(Long, DP)] =
      datasets.map { case (ds, idx) => (idx, ds._1) }.collect()

    indexedDps.toList.map { case (idx, dp) =>
      val dataset = datasets.filter(_._2 == idx).map(_._1)
      val td = dataset.map(_._2)
      val qa = dataset.flatMap(_._3.toSeq)
      (dp, td, qa)
    }
  }

  /**
   * Implement this method to return only the training data from a data
   * source.
   *
   * The default implementation returns `null`. The default [[read]] calls
   * this method to obtain the training data of its single data set.
   */
  def readTraining(): TD = null.asInstanceOf[TD]

  /**
   * Implement this method to return one set of test data (the data
   * parameters plus an Iterable of query and actual value pairs) from a data
   * source. You should also implement [[readTraining]] to return the
   * corresponding training data.
   *
   * The default implementation returns `null` data parameters and an empty
   * list of pairs.
   */
  def readTest(): Tuple2[DP, JIterable[Tuple2[Q, A]]] =
    (null.asInstanceOf[DP], JCollections.emptyList())

  /**
   * Implement this method to return one or more sets of training data and
   * test data (an Iterable of query and actual value pairs) from a data
   * source.
   *
   * The default implementation returns a single set built from
   * [[readTest]] and [[readTraining]], called in that order.
   */
  def read(): JIterable[Tuple3[DP, TD, JIterable[Tuple2[Q, A]]]] = {
    val (dp, qa) = readTest()
    List((dp, readTraining(), qa))
  }
}