blob: b391203b4b968c3f17c0e8f477976a98856aa017 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis
import java.util.concurrent.TimeUnit
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import com.amazonaws.auth.AWSCredentials
import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord
import com.amazonaws.services.kinesis.model._
import org.apache.spark._
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{ERROR, NUM_RETRY}
import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition}
import org.apache.spark.storage.BlockId
import org.apache.spark.util.NextIterator
/**
 * An inclusive span of Kinesis sequence numbers inside a single shard of a stream.
 *
 * @param streamName    name of the Kinesis stream
 * @param shardId       identifier of the shard the range belongs to
 * @param fromSeqNumber first sequence number of the range (inclusive)
 * @param toSeqNumber   last sequence number of the range (inclusive)
 * @param recordCount   record count associated with the range; it is used (capped) as the
 *                      per-request record limit when the range is read back from Kinesis
 */
private[kinesis] case class SequenceNumberRange(
    streamName: String,
    shardId: String,
    fromSeqNumber: String,
    toSeqNumber: String,
    recordCount: Int)
/** Wrapper around a sequence of [[SequenceNumberRange]]s. */
private[kinesis] case class SequenceNumberRanges(ranges: Seq[SequenceNumberRange]) {

  /** True when no range is held. */
  def isEmpty(): Boolean = ranges.isEmpty

  /** True when at least one range is held. */
  def nonEmpty(): Boolean = !ranges.isEmpty

  /** Renders the wrapper as `SequenceNumberRanges(r1, r2, ...)`. */
  override def toString(): String = s"SequenceNumberRanges(${ranges.mkString(", ")})"
}
/** Companion providing a convenience factory for the single-range case. */
private[kinesis] object SequenceNumberRanges {

  /** Wraps one range into a [[SequenceNumberRanges]] that holds just that element. */
  def apply(range: SequenceNumberRange): SequenceNumberRanges =
    new SequenceNumberRanges(range :: Nil)
}
/**
 * Partition of a [[KinesisBackedBlockRDD]]: a block plus the Kinesis sequence number ranges
 * from which the block's data can be read.
 *
 * @param idx             index of the partition
 * @param blockId         id of the block associated with this partition
 * @param isBlockIdValid  whether the block should first be looked up in the block manager
 *                        before falling back to Kinesis
 * @param seqNumberRanges ranges of Kinesis sequence numbers to read for this partition
 */
private[kinesis] class KinesisBackedBlockRDDPartition(
    idx: Int,
    blockId: BlockId,
    val isBlockIdValid: Boolean,
    val seqNumberRanges: SequenceNumberRanges)
  extends BlockRDDPartition(blockId, idx)
/**
 * A BlockRDD whose block data is backed by Kinesis: the data of a partition can be accessed
 * again from Kinesis using the sequence number ranges recorded for the corresponding block.
 *
 * The i-th block id, the i-th entry of `arrayOfseqNumberRanges` and (when supplied) the i-th
 * entry of `isBlockIdValid` describe the same partition. An empty `isBlockIdValid` means every
 * block id is treated as valid.
 */
private[kinesis] class KinesisBackedBlockRDD[T: ClassTag](
    sc: SparkContext,
    val regionName: String,
    val endpointUrl: String,
    @transient private val _blockIds: Array[BlockId],
    @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
    @transient private val isBlockIdValid: Array[Boolean] = Array.empty,
    val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _,
    val kinesisCreds: SparkAWSCredentials = DefaultCredentials,
    val kinesisReadConfigs: KinesisReadConfigurations = KinesisReadConfigurations())
  extends BlockRDD[T](sc, _blockIds) {

  require(_blockIds.length == arrayOfseqNumberRanges.length,
    "Number of blockIds is not equal to the number of sequence number ranges")

  /** This RDD always reports itself as valid. */
  override def isValid: Boolean = true

  /** Builds one [[KinesisBackedBlockRDDPartition]] per block id. */
  override def getPartitions: Array[Partition] = {
    def partitionAt(index: Int): Partition = {
      // An empty validity array stands for "all block ids are valid".
      val blockIsValid = isBlockIdValid.isEmpty || isBlockIdValid(index)
      new KinesisBackedBlockRDDPartition(
        index, _blockIds(index), blockIsValid, arrayOfseqNumberRanges(index))
    }
    Array.tabulate[Partition](_blockIds.length)(partitionAt)
  }

  /**
   * Produces the data of a partition. When the partition's block id is flagged valid, the block
   * manager is asked first; if the block is not found there, or the id is not valid, the records
   * are read from Kinesis and passed through `messageHandler`.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val blockManager = SparkEnv.get.blockManager
    val kinesisPartition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
    val blockId = kinesisPartition.blockId

    // Looks the block up in the block manager; None when it is not available there.
    def readFromBlockManager(): Option[Iterator[T]] = {
      logDebug(s"Read partition data of $this from block manager, block $blockId")
      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
    }

    // Lazily walks every sequence number range of the partition, one range iterator at a time.
    def readFromKinesis(): Iterator[T] = {
      val credentials = kinesisCreds.provider.getCredentials
      for {
        range <- kinesisPartition.seqNumberRanges.ranges.iterator
        record <- new KinesisSequenceRangeIterator(
          credentials, endpointUrl, regionName, range, kinesisReadConfigs)
      } yield messageHandler(record)
    }

    val fromBlockManager =
      if (kinesisPartition.isBlockIdValid) readFromBlockManager() else None
    fromBlockManager.getOrElse(readFromKinesis())
  }
}
/**
 * An iterator that returns the Kinesis records lying in the given range of sequence numbers.
 * Internally, it repeatedly fetches sets of records starting from `range.fromSeqNumber`, until
 * a record whose sequence number equals `range.toSeqNumber` has been returned.
 *
 * The Kinesis client created by this iterator is shut down when the iterator finishes normally
 * and also when fetching the next record fails, so that a failed read does not leak the client.
 *
 * @param credentials        AWS credentials used for the client and attached to every request
 * @param endpointUrl        Kinesis endpoint the client is pointed at
 * @param regionId           region identifier (not referenced by the code of this class)
 * @param range              stream, shard and inclusive sequence number range to read
 * @param kinesisReadConfigs retry settings (max retries, initial wait, overall timeout)
 */
private[kinesis] class KinesisSequenceRangeIterator(
    credentials: AWSCredentials,
    endpointUrl: String,
    regionId: String,
    range: SequenceNumberRange,
    kinesisReadConfigs: KinesisReadConfigurations)
  extends NextIterator[Record] with Logging {

  private val client = new AmazonKinesisClient(credentials)
  private val streamName = range.streamName
  private val shardId = range.shardId
  // AWS limits to maximum of 10k records per get call
  private val maxGetRecordsLimit = 10000

  // Set once the record carrying range.toSeqNumber has been handed out.
  private var toSeqNumberReceived = false
  // Sequence number of the most recently returned record; null until the first record.
  private var lastSeqNumber: String = null
  // Records of the batch currently being consumed; null until the first fetch.
  private var internalIterator: Iterator[Record] = null

  client.setEndpoint(endpointUrl)

  /**
   * Returns the next record of the range, or marks the iterator finished once the end of the
   * range has been returned.
   *
   * If fetching fails for any reason, the iterator is marked finished and the Kinesis client is
   * released before the failure propagates. Without this the client would stay open, because
   * the normal close-on-finish handling only runs after a successful return.
   */
  override protected def getNext(): Record = {
    var succeeded = false
    try {
      val record = fetchNext()
      succeeded = true
      record
    } finally {
      if (!succeeded) {
        // The client is about to be shut down, so this iterator must not be used again.
        finished = true
        closeIfNeeded()
      }
    }
  }

  /** Shuts down the Kinesis client owned by this iterator. */
  override protected def close(): Unit = {
    client.shutdown()
  }

  /**
   * Core of [[getNext]]: fetches a new batch when needed and returns the next record.
   *
   * @throws SparkException if no further record can be fetched before the end sequence number
   *                        of the range has been seen
   */
  private def fetchNext(): Record = {
    var nextRecord: Record = null
    if (toSeqNumberReceived) {
      finished = true
    } else {
      if (internalIterator == null) {
        // If the internal iterator has not been initialized,
        // then fetch records from starting sequence number
        internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber,
          range.recordCount)
      } else if (!internalIterator.hasNext) {
        // If the internal iterator does not have any more records,
        // then fetch more records after the last consumed sequence number
        internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber,
          range.recordCount)
      }

      if (!internalIterator.hasNext) {
        // If the internal iterator still does not have any data, then throw exception
        // and terminate this iterator
        finished = true
        throw new SparkException(
          s"Could not read until the end sequence number of the range: $range")
      } else {
        // Get the record and remember its sequence number
        nextRecord = internalIterator.next()
        lastSeqNumber = nextRecord.getSequenceNumber()

        // If this record's sequence number matches the stopping sequence number, then make sure
        // the iterator is marked finished next time getNext() is called
        if (nextRecord.getSequenceNumber == range.toSeqNumber) {
          toSeqNumberReceived = true
        }
      }
    }
    nextRecord
  }

  /**
   * Get records starting from or after the given sequence number.
   */
  private def getRecords(
      iteratorType: ShardIteratorType,
      seqNum: String,
      recordCount: Int): Iterator[Record] = {
    val shardIterator = getKinesisIterator(iteratorType, seqNum)
    val result = getRecordsAndNextKinesisIterator(shardIterator, recordCount)
    result._1
  }

  /**
   * Get the records using a Kinesis shard iterator (which is a progress handle to get records
   * from Kinesis), together with the next shard iterator for further consumption.
   */
  private def getRecordsAndNextKinesisIterator(
      shardIterator: String,
      recordCount: Int): (Iterator[Record], String) = {
    val getRecordsRequest = new GetRecordsRequest
    getRecordsRequest.setRequestCredentials(credentials)
    getRecordsRequest.setShardIterator(shardIterator)
    getRecordsRequest.setLimit(Math.min(recordCount, this.maxGetRecordsLimit))
    val getRecordsResult = retryOrTimeout[GetRecordsResult](
      "getting records using shard iterator") {
      client.getRecords(getRecordsRequest)
    }
    // De-aggregate records, if KPL was used in producing the records. The KCL automatically
    // handles de-aggregation during regular operation. This code path is used during recovery
    val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords)
    (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator)
  }

  /**
   * Get the Kinesis shard iterator for getting records starting from or after the given
   * sequence number.
   */
  private def getKinesisIterator(
      iteratorType: ShardIteratorType,
      sequenceNumber: String): String = {
    val getShardIteratorRequest = new GetShardIteratorRequest
    getShardIteratorRequest.setRequestCredentials(credentials)
    getShardIteratorRequest.setStreamName(streamName)
    getShardIteratorRequest.setShardId(shardId)
    getShardIteratorRequest.setShardIteratorType(iteratorType.toString)
    getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber)
    val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult](
      s"getting shard iterator from sequence number $sequenceNumber") {
      client.getShardIterator(getShardIteratorRequest)
    }
    getShardIteratorResult.getShardIterator
  }

  /**
   * Helper method to retry Kinesis API request with exponential backoff and timeouts.
   *
   * Only `ProvisionedThroughputExceededException` is retried; any other non-fatal error is
   * wrapped in a `SparkException` and rethrown immediately. Retrying stops once either the
   * configured number of attempts or the configured overall timeout has been reached.
   *
   * @param message description of the operation, used in log and error messages
   * @param body    the Kinesis call to execute
   * @throws SparkException on a non-retryable error, on timeout, or when attempts are exhausted
   */
  private def retryOrTimeout[T](message: String)(body: => T): T = {
    val startTimeNs = System.nanoTime()
    var retryCount = 0
    var result: Option[T] = None
    var lastError: Throwable = null
    var waitTimeInterval = kinesisReadConfigs.retryWaitTimeMs

    def isTimedOut = {
      val retryTimeoutNs = TimeUnit.MILLISECONDS.toNanos(kinesisReadConfigs.retryTimeoutMs)
      (System.nanoTime() - startTimeNs) >= retryTimeoutNs
    }
    def isMaxRetryDone = retryCount >= kinesisReadConfigs.maxRetries

    while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
      if (retryCount > 0) { // wait only if this is a retry
        Thread.sleep(waitTimeInterval)
        waitTimeInterval *= 2 // if you have waited, then double wait time for next round
      }
      try {
        result = Some(body)
      } catch {
        case NonFatal(t) =>
          lastError = t
          t match {
            case ptee: ProvisionedThroughputExceededException =>
              logWarning(log"Error while ${MDC(ERROR, message)} " +
                log"[attempt = ${MDC(NUM_RETRY, retryCount + 1)}]", ptee)
            case e: Throwable =>
              throw new SparkException(s"Error while $message", e)
          }
      }
      retryCount += 1
    }
    result.getOrElse {
      if (isTimedOut) {
        throw new SparkException(
          s"Timed out after ${kinesisReadConfigs.retryTimeoutMs} ms while " +
            s"$message, last exception: ", lastError)
      } else {
        throw new SparkException(
          s"Gave up after $retryCount retries while $message, last exception: ", lastError)
      }
    }
  }
}