/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset}
import org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.{PythonStreamBlockId, StorageLevel}
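
// Offset of a Python streaming source, wrapping the JSON string produced by the
// user-defined Python reader.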
case class PythonStreamingSourceOffset(json: String) extends Offset
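
/**
 * [[MicroBatchStream]] implementation for a user-defined Python streaming data source.
 * Offset management and partition planning are delegated to the Python side through a
 * [[PythonStreamingSourceRunner]].
 */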
class PythonMicroBatchStream(
    ds: PythonDataSourceV2,
    shortName: String,
    outputSchema: StructType,
    options: CaseInsensitiveStringMap)
  extends MicroBatchStream
  with Logging
  with AcceptsLatestSeenOffset {
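
  // Python function wrapping the user-defined data source instance, passed to the
  // streaming source runner below.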
  private def createDataSourceFunc =
    ds.source.createPythonFunction(
      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource)
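
  // Identifies this stream and numbers the blocks that cache prefetched rows.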
  private val streamId = nextStreamId
  private var nextBlockId = 0L

  // planInputPartitions() may be called multiple times for the current microbatch.
  // Cache the result of planInputPartitions() because it may involve sending data
  // from Python to the JVM.
  private var cachedInputPartition: Option[(String, String, PythonStreamingInputPartition)] = None
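
  // Long-lived runner that forwards offset, partition planning and commit requests to
  // the Python streaming reader.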
  private val runner: PythonStreamingSourceRunner =
    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
  runner.init()
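
  // Initial and latest offsets are computed by the Python reader and returned as JSON.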
  override def initialOffset(): Offset = PythonStreamingSourceOffset(runner.initialOffset())

  override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset())
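
  // Plans the partitions for the given offset range. If the Python reader prefetched
  // the data (simple reader), the rows are stored in the block manager and the single
  // resulting partition is cached, since this method may be called multiple times for
  // the same microbatch.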
  override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
    val startOffsetJson = start.asInstanceOf[PythonStreamingSourceOffset].json
    val endOffsetJson = end.asInstanceOf[PythonStreamingSourceOffset].json

    if (cachedInputPartition.exists(p => p._1 == startOffsetJson && p._2 == endOffsetJson)) {
      return Array(cachedInputPartition.get._3)
    }

    val (partitions, rows) = runner.partitions(startOffsetJson, endOffsetJson)
    if (rows.isDefined) {
      // Only SimpleStreamReader without partitioning prefetches data.
      assert(partitions.length == 1)
      nextBlockId = nextBlockId + 1
      val blockId = PythonStreamBlockId(streamId, nextBlockId)
      SparkEnv.get.blockManager.putIterator(
        blockId, rows.get, StorageLevel.MEMORY_AND_DISK_SER, true)
      val partition = PythonStreamingInputPartition(0, partitions.head, Some(blockId))
      cachedInputPartition.foreach(_._3.dropCache())
      cachedInputPartition = Some((startOffsetJson, endOffsetJson, partition))
      Array(partition)
    } else {
      partitions.zipWithIndex
        .map(p => PythonStreamingInputPartition(p._2, p._1, None))
    }
  }
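
  // AcceptsLatestSeenOffset hook: called with the latest offset the query has already
  // seen (e.g. on restart) before any new planning happens.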
  override def setLatestSeenOffset(offset: Offset): Unit = {
    // Plan partitions on the Python side with an empty offset range to initialize the
    // start offset for the prefetching of the simple reader.
    runner.partitions(offset.json(), offset.json())
  }
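
  // Pickled Python read function, created lazily and shared by all partition readers
  // of this stream.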
  private lazy val readInfo: PythonDataSourceReadInfo = {
    ds.source.createReadInfoInPython(
      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
      outputSchema,
      isStreaming = true)
  }
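
  // Partition readers run the pickled Python read function for their assigned partition.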
  override def createReaderFactory(): PartitionReaderFactory = {
    new PythonStreamingPartitionReaderFactory(
      ds.source, readInfo.func, outputSchema, None)
  }
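
  // Tells the Python reader that everything up to `end` has been processed.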
  override def commit(end: Offset): Unit = {
    runner.commit(end.asInstanceOf[PythonStreamingSourceOffset].json)
  }
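
  // Drops any cached prefetched block and shuts down the Python runner.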
  override def stop(): Unit = {
    cachedInputPartition.foreach(_._3.dropCache())
    runner.stop()
  }

  override def deserializeOffset(json: String): Offset = PythonStreamingSourceOffset(json)
}
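
// Allocates process-wide unique stream ids so that block ids of concurrent streams
// do not collide.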
object PythonMicroBatchStream {
  private var currentId = 0

  def nextStreamId: Int = synchronized {
    currentId = currentId + 1
    currentId
  }
}