blob: 53b61b1b1e90a0d050e5a138a87eac4ec45a6b53 [file] [log] [blame]
package io.pivotal.gemfire.spark.connector.internal.rdd
import com.gemstone.gemfire.cache.Region
import io.pivotal.gemfire.spark.connector.GemFireConnectionConf
import org.apache.spark.{TaskContext, Partition}
import org.apache.spark.rdd.RDD
import scala.collection.JavaConversions._
/**
 * An `RDD[(T, V)]` that represents the result of an inner join between the `left` RDD[T]
 * and the specified GemFire Region[K, V].
 *
 * Each element `t` of `left` is looked up in the region by a key: `func(t)` when `func`
 * is provided, otherwise `t` must itself be a `(K, V1)` pair and its first component is
 * used as the key. Elements whose key has no (non-null) value in the region are dropped.
 *
 * @param left       the RDD whose elements are joined with the region
 * @param func       maps a `left` element to its region key; may be `null`, in which case
 *                   `T` is expected to be a `(K, V1)` pair
 * @param regionPath the path of the GemFire region to join with
 * @param connConf   the GemFire connection configuration
 */
class GemFireJoinRDD[T, K, V] private[connector]
  ( left: RDD[T],
    func: T => K,
    val regionPath: String,
    val connConf: GemFireConnectionConf
  ) extends RDD[(T, V)](left) {
  // Depend on `left` itself (one-to-one, partition-for-partition, matching `compute`) rather
  // than on `left.dependencies`, so `left` stays in the lineage (e.g. its checkpoint is
  // materialized when a job runs on this RDD).

  /** validate region existence when GemFireJoinRDD object is created */
  validate()

  /** Validate region, and make sure it exists. */
  private def validate(): Unit = connConf.getConnection.validateRegion[K, V](regionPath)

  /** One partition per `left` partition; `compute` reads the matching `left` partition. */
  override protected def getPartitions: Array[Partition] = left.partitions

  override def compute(split: Partition, context: TaskContext): Iterator[(T, V)] = {
    val keyOf: T => K = if (func == null) pairKey _ else func
    // Materialize the partition: elements are needed twice (bulk key lookup, then pairing).
    val leftPairs = left.iterator(split, context).map(t => (t, keyOf(t))).toList
    // Nothing to join: skip the region lookup (and its remote round trip) entirely.
    if (leftPairs.isEmpty) Iterator.empty
    else {
      val region = connConf.getConnection.getRegionProxy[K, V](regionPath)
      joinWithRegion(leftPairs, region)
    }
  }

  /** Fetches all distinct keys with one `getAll` call and pairs each element with its value. */
  private def joinWithRegion(leftPairs: List[(T, K)], region: Region[K, V]): Iterator[(T, V)] = {
    val leftKeys: Set[K] = leftPairs.map(_._2).toSet
    // Note: getAll returns (key, null) for a non-existent entry, so a null value means "no match".
    val rightValues = region.getAll(leftKeys)
    leftPairs.iterator.flatMap { case (t, k) => Option(rightValues.get(k)).map(v => (t, v)) }
  }

  /** Key extractor used when `func` is null: `t` must be a `(K, V1)` pair. */
  private def pairKey(t: T): K = t match {
    case (k, _) => k.asInstanceOf[K]
    case other => throw new IllegalArgumentException(
      s"GemFireJoinRDD without a key function requires (key, value) pairs, but got: $other")
  }
}