blob: 0b549870a3482429578857a92379a5d992fb431e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import java.{util => ju}
import java.util.concurrent.TimeoutException
import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, OffsetOutOfRangeException}
import org.apache.kafka.common.TopicPartition
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A [[ContinuousStream]] for data from kafka.
*
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
* read by per-task consumers generated later.
* @param kafkaParams String params for per-task Kafka consumers.
* @param options Params which are not Kafka consumer params.
* @param metadataPath Path to a directory this reader can use for writing metadata.
* @param initialOffsets The Kafka offsets to start reading data at.
* @param failOnDataLoss Flag indicating whether reading should fail in data loss
* scenarios, where some offsets after the specified initial ones can't be
* properly read.
*/
class KafkaContinuousStream(
private[kafka010] val offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousStream with Logging {
private[kafka010] val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
// Exposed outside this object only for unit tests.
@volatile private[sql] var knownPartitions: Set[TopicPartition] = _
override def initialOffset(): Offset = {
val offsets = initialOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets(None))
case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
case SpecificTimestampRangeLimit(p) => offsetReader.fetchSpecificTimestampBasedOffsets(p,
failsOnNoMatchingOffset = true)
}
logInfo(s"Initial offsets: $offsets")
offsets
}
override def deserializeOffset(json: String): Offset = {
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}
override def planInputPartitions(start: Offset): Array[InputPartition] = {
val oldStartPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
if (deletedPartitions.nonEmpty) {
val message = if (
offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
reportDataLoss(message)
}
val startOffsets = newPartitionOffsets ++
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
knownPartitions = startOffsets.keySet
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}
override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
KafkaContinuousReaderFactory
}
/** Stop this source and free any resources it has allocated. */
def stop(): Unit = synchronized {
offsetReader.close()
}
override def commit(end: Offset): Unit = {}
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
val mergedMap = offsets.map {
case KafkaSourcePartitionOffset(p, o) => Map(p -> o)
}.reduce(_ ++ _)
KafkaSourceOffset(mergedMap)
}
override def needsReconfiguration(): Boolean = {
offsetReader.fetchLatestOffsets(None).keySet != knownPartitions
}
override def toString(): String = s"KafkaSource[$offsetReader]"
/**
* If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
* Otherwise, just log a warning.
*/
private def reportDataLoss(message: String): Unit = {
if (failOnDataLoss) {
throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
} else {
logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
}
}
}
/**
* An input partition for continuous Kafka processing. This will be serialized and transformed
* into a full reader on executors.
*
* @param topicPartition The (topic, partition) pair this task is responsible for.
* @param startOffset The offset to start reading from within the partition.
* @param kafkaParams Kafka consumer params to use.
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
* @param includeHeaders Flag indicating whether to include Kafka records' headers.
*/
case class KafkaContinuousInputPartition(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends InputPartition
object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaContinuousInputPartition]
new KafkaContinuousPartitionReader(
p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs,
p.failOnDataLoss, p.includeHeaders)
}
}
/**
* A per-task data reader for continuous Kafka processing.
*
* @param topicPartition The (topic, partition) pair this data reader is responsible for.
* @param startOffset The offset to start reading from within the partition.
* @param kafkaParams Kafka consumer params to use.
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
class KafkaContinuousPartitionReader(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
private val unsafeRowProjector = new KafkaRecordToRowConverter()
.toUnsafeRowProjector(includeHeaders)
private var nextKafkaOffset = startOffset
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
override def next(): Boolean = {
var r: ConsumerRecord[Array[Byte], Array[Byte]] = null
while (r == null) {
if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false
// Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving
// interrupt points to end the query rather than waiting for new data that might never come.
try {
r = consumer.get(
nextKafkaOffset,
untilOffset = Long.MaxValue,
pollTimeoutMs,
failOnDataLoss)
} catch {
// We didn't read within the timeout. We're supposed to block indefinitely for new data, so
// swallow and ignore this.
case _: TimeoutException | _: org.apache.kafka.common.errors.TimeoutException =>
// This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range,
// or if it's the endpoint of the data range (i.e. the "true" next offset).
case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] =>
val range = consumer.getAvailableOffsetRange()
if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) {
// retry
} else {
throw e
}
}
}
nextKafkaOffset = r.offset + 1
currentRecord = r
true
}
override def get(): UnsafeRow = {
unsafeRowProjector(currentRecord)
}
override def getOffset(): KafkaSourcePartitionOffset = {
KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
}
override def close(): Unit = {
consumer.release()
}
}