/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources

import java.util.ArrayList

import org.apache.yetus.audience.InterfaceAudience
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.spark._
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.spark.datasources.HBaseResources._
import org.apache.hadoop.hbase.util.ShutdownHookManager
import org.apache.spark.{SparkEnv, TaskContext, Partition}
import org.apache.spark.rdd.RDD

import scala.collection.mutable
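
/**
 * RDD of HBase [[Result]]s backing a table scan. One Spark partition is created per
 * HBase region that intersects the requested row-key ranges or point gets, with the
 * column projection and the SparkSQL pushdown filter evaluated on the region servers.
 */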
@InterfaceAudience.Private
class HBaseTableScanRDD(
    relation: HBaseRelation,
    val hbaseContext: HBaseContext,
    @transient val filter: Option[SparkSQLPushDownFilter] = None,
    val columns: Seq[Field] = Seq.empty)
  extends RDD[Result](relation.sqlContext.sparkContext, Nil) {

  private def sparkConf = SparkEnv.get.conf

  @transient var ranges = Seq.empty[Range]
  @transient var points = Seq.empty[Array[Byte]]

  def addPoint(p: Array[Byte]): Unit = {
    points :+= p
  }
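
  // Record a pushed-down ScanRange as an internal Range. An inclusive upper bound is
  // rewritten here because HBase's stopRow is always exclusive.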
  def addRange(r: ScanRange): Unit = {
    val lower = if (r.lowerBound != null && r.lowerBound.length > 0) {
      Some(Bound(r.lowerBound, r.isLowerBoundEqualTo))
    } else {
      None
    }
    val upper = if (r.upperBound != null && r.upperBound.length > 0) {
      if (!r.isUpperBoundEqualTo) {
        Some(Bound(r.upperBound, false))
      } else {
        // HBase's stopRow is exclusive, so it cannot express an inclusive upper bound
        // directly. Append a trailing minimum byte to the stopRow: every row equal to
        // the original upper bound sorts before the new stopRow and is therefore included.
        val newArray = new Array[Byte](r.upperBound.length + 1)
        System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
        newArray(r.upperBound.length) = ByteMin
        Some(Bound(newArray, false))
      }
    } else {
      None
    }
    ranges :+= Range(lower, upper)
  }
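
  // One HBaseScanPartition per region whose key space overlaps the requested ranges or
  // points; regions with no matching work produce no partition at all.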
  override def getPartitions: Array[Partition] = {
    val regions = RegionResource(relation)
    var idx = 0
    logDebug(s"There are ${regions.size} regions")
    val ps = regions.flatMap { x =>
      val rs = Ranges.and(Range(x), ranges)
      val ps = Points.and(Range(x), points)
      if (rs.size > 0 || ps.size > 0) {
        if (log.isDebugEnabled) {
          rs.foreach(x => logDebug(x.toString))
        }
        idx += 1
        Some(HBaseScanPartition(idx - 1, x, rs, ps, SerializedFilter.toSerializedTypedFilter(filter)))
      } else {
        None
      }
    }.toArray
    regions.release()
    ShutdownHookManager.affixShutdownHook(new Thread() {
      override def run(): Unit = {
        HBaseConnectionCache.close()
      }
    }, 0)
    ps.asInstanceOf[Array[Partition]]
  }
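
  // Prefer running each partition on the region server that hosts its region.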
  override def getPreferredLocations(split: Partition): Seq[String] = {
    split.asInstanceOf[HBaseScanPartition].regions.server.map {
      identity
    }.toSeq
  }
  private def buildGets(
      tbr: TableResource,
      g: Seq[Array[Byte]],
      filter: Option[SparkSQLPushDownFilter],
      columns: Seq[Field],
      hbaseContext: HBaseContext): Iterator[Result] = {
    g.grouped(relation.bulkGetSize).flatMap { x =>
      val gets = new ArrayList[Get](x.size)
      x.foreach { y =>
        val g = new Get(y)
        handleTimeSemantics(g)
        columns.foreach { d =>
          if (!d.isRowKey) {
            g.addColumn(d.cfBytes, d.colBytes)
          }
        }
        filter.foreach(g.setFilter(_))
        gets.add(g)
      }
      hbaseContext.applyCreds()
      val tmp = tbr.get(gets)
      rddResources.addResource(tmp)
      toResultIterator(tmp)
    }
  }
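
  // Lazily walk the results of a multi-get, skipping empty Results and releasing the
  // underlying resource once the batch is exhausted.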
  private def toResultIterator(result: GetResource): Iterator[Result] = {
    val iterator = new Iterator[Result] {
      var idx = 0
      var cur: Option[Result] = None

      override def hasNext: Boolean = {
        while (idx < result.length && cur.isEmpty) {
          val r = result(idx)
          idx += 1
          if (!r.isEmpty) {
            cur = Some(r)
          }
        }
        if (cur.isEmpty) {
          rddResources.release(result)
        }
        cur.isDefined
      }

      override def next(): Result = {
        hasNext
        val ret = cur.get
        cur = None
        ret
      }
    }
    iterator
  }
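
  // Build the Scan for one row-key Range: start row inclusive, stop row exclusive
  // (inclusive upper bounds were already rewritten in addRange), plus projection,
  // caching settings, and the pushdown filter.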
  private def buildScan(
      range: Range,
      filter: Option[SparkSQLPushDownFilter],
      columns: Seq[Field]): Scan = {
    val scan = (range.lower, range.upper) match {
      case (Some(Bound(a, b)), Some(Bound(c, d))) => new Scan(a, c)
      case (None, Some(Bound(c, d))) => new Scan(Array[Byte](), c)
      case (Some(Bound(a, b)), None) => new Scan(a)
      case (None, None) => new Scan()
    }
    handleTimeSemantics(scan)
    columns.foreach { d =>
      if (!d.isRowKey) {
        scan.addColumn(d.cfBytes, d.colBytes)
      }
    }
    scan.setCacheBlocks(relation.blockCacheEnable)
    scan.setBatch(relation.batchNum)
    scan.setCaching(relation.cacheSize)
    filter.foreach(scan.setFilter(_))
    scan
  }
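
  // Wrap an open scanner as an Iterator[Result], releasing it as soon as it runs dry.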
  private def toResultIterator(scanner: ScanResource): Iterator[Result] = {
    val iterator = new Iterator[Result] {
      var cur: Option[Result] = None

      override def hasNext: Boolean = {
        if (cur.isEmpty) {
          val r = scanner.next()
          if (r == null) {
            rddResources.release(scanner)
          } else {
            cur = Some(r)
          }
        }
        cur.isDefined
      }

      override def next(): Result = {
        hasNext
        val ret = cur.get
        cur = None
        ret
      }
    }
    iterator
  }

  lazy val rddResources = RDDResources(new mutable.HashSet[Resource]())

  private def close(): Unit = {
    rddResources.release()
  }
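
  // Per-partition computation: run the range scans in parallel, chain their results
  // together, and append the results of the bulk gets for any point lookups.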
  override def compute(split: Partition, context: TaskContext): Iterator[Result] = {
    val partition = split.asInstanceOf[HBaseScanPartition]
    val filter = SerializedFilter.fromSerializedFilter(partition.sf)
    val scans = partition.scanRanges
      .map(buildScan(_, filter, columns))
    val tableResource = TableResource(relation)
    context.addTaskCompletionListener[Unit](context => close())
    val points = partition.points
    val gIt: Iterator[Result] = {
      if (points.isEmpty) {
        Iterator.empty: Iterator[Result]
      } else {
        buildGets(tableResource, points, filter, columns, hbaseContext)
      }
    }
    val rIts = scans.par
      .map { scan =>
        hbaseContext.applyCreds()
        val scanner = tableResource.getScanner(scan)
        rddResources.addResource(scanner)
        scanner
      }.map(toResultIterator(_))
      .fold(Iterator.empty: Iterator[Result]) { case (x, y) =>
        x ++ y
      } ++ gIt
    ShutdownHookManager.affixShutdownHook(new Thread() {
      override def run(): Unit = {
        HBaseConnectionCache.close()
      }
    }, 0)
    rIts
  }
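
  // Apply the optional timestamp / time-range / max-versions settings from the relation
  // to a Scan or Get; any other combination of options is rejected.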
  private def handleTimeSemantics(query: Query): Unit = {
    // Set timestamp related values if present
    (query, relation.timestamp, relation.minTimestamp, relation.maxTimestamp) match {
      case (q: Scan, Some(ts), None, None) => q.setTimeStamp(ts)
      case (q: Get, Some(ts), None, None) => q.setTimeStamp(ts)
      case (q: Scan, None, Some(minStamp), Some(maxStamp)) => q.setTimeRange(minStamp, maxStamp)
      case (q: Get, None, Some(minStamp), Some(maxStamp)) => q.setTimeRange(minStamp, maxStamp)
      case (q, None, None, None) =>
      case _ => throw new IllegalArgumentException(
        "Invalid combination of query/timestamp/time range provided. " +
          s"timeStamp is: ${relation.timestamp}, minTimeStamp is: ${relation.minTimestamp}, " +
          s"maxTimeStamp is: ${relation.maxTimestamp}")
    }
    if (relation.maxVersions.isDefined) {
      query match {
        case q: Scan => q.setMaxVersions(relation.maxVersions.get)
        case q: Get => q.setMaxVersions(relation.maxVersions.get)
        case _ => throw new IllegalArgumentException("Invalid query provided with maxVersions")
      }
    }
  }
}
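
/**
 * Byte-array form of the SparkSQL pushdown filter, kept serialized so it can be shipped
 * inside each HBaseScanPartition and rebuilt on the executors.
 */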
@InterfaceAudience.Private
case class SerializedFilter(b: Option[Array[Byte]])
object SerializedFilter {
  def toSerializedTypedFilter(f: Option[SparkSQLPushDownFilter]): SerializedFilter = {
    SerializedFilter(f.map(_.toByteArray))
  }

  def fromSerializedFilter(sf: SerializedFilter): Option[SparkSQLPushDownFilter] = {
    sf.b.map(SparkSQLPushDownFilter.parseFrom(_))
  }
}
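
/**
 * A single HBase region: its start/end keys and, when known, the server hosting it,
 * used as the unit from which scan partitions are derived.
 */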
@InterfaceAudience.Private
private[hbase] case class HBaseRegion(
    override val index: Int,
    val start: Option[HBaseType] = None,
    val end: Option[HBaseType] = None,
    val server: Option[String] = None) extends Partition
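
/**
 * One Spark partition of the scan: the region it covers, the row-key ranges and point
 * gets that fall inside it, and the serialized pushdown filter to apply.
 */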
@InterfaceAudience.Private
private[hbase] case class HBaseScanPartition(
    override val index: Int,
    val regions: HBaseRegion,
    val scanRanges: Seq[Range],
    val points: Seq[Array[Byte]],
    val sf: SerializedFilter) extends Partition
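
/**
 * Tracks the scanner and get resources opened by this RDD so they can all be released
 * when a task completes or a result iterator is exhausted.
 */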
@InterfaceAudience.Private
case class RDDResources(set: mutable.HashSet[Resource]) {
  def addResource(s: Resource): Unit = {
    set += s
  }

  def release(): Unit = {
    set.foreach(release(_))
  }

  def release(rs: Resource): Unit = {
    try {
      rs.release()
    } finally {
      set.remove(rs)
    }
  }
}