/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.client

import java.lang.ref.Cleaner
import java.util.Objects

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
import org.apache.arrow.vector.types.pojo

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils
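
/**
 * Client-side holder for the results of a Spark Connect `ExecutePlan` call.
 *
 * Responses are consumed lazily: Arrow IPC batches are buffered in memory until they are
 * deserialized through one of the iterators or `toArray`, and are released on `close()` (or by
 * the shared Cleaner if `close()` is never called).
 *
 * A minimal internal-usage sketch, using only the members defined in this file (inputs elided):
 * {{{
 *   val result = new SparkResult(responses, allocator, encoder, timeZoneId)
 *   try {
 *     result.toArray.foreach(println)
 *   } finally {
 *     result.close()
 *   }
 * }}}
 */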
private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String,
setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
extends AutoCloseable { self =>
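  /** Per-stage progress information derived from the server-reported ExecutionProgress. */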
case class StageInfo(
stageId: Long,
numTasks: Long,
completedTasks: Long = 0,
inputBytesRead: Long = 0,
completed: Boolean = false)
object StageInfo {
def apply(stageInfo: proto.ExecutePlanResponse.ExecutionProgress.StageInfo): StageInfo = {
StageInfo(
stageInfo.getStageId,
stageInfo.getNumTasks,
stageInfo.getNumCompletedTasks,
stageInfo.getInputBytesRead,
stageInfo.getDone)
}
}
object Progress {
def apply(progress: proto.ExecutePlanResponse.ExecutionProgress): Progress = {
Progress(
progress.getStagesList.asScala.map(StageInfo(_)).toSeq,
progress.getNumInflightTasks)
}
}
/**
* Progress of the query execution. This information can be accessed from the iterator.
*/
case class Progress(stages: Seq[StageInfo], inflight: Long)
var progress: Progress = new Progress(Seq.empty, 0)
private[this] var opId: String = _
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
private[this] var arrowSchema: pojo.Schema = _
private[this] var nextResultIndex: Int = 0
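  // Arrow messages of each non-empty batch received so far, keyed by arrival order. The Long is
  // the number of IPC bytes read for that batch.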
private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
private val observedMetrics = mutable.Map.empty[String, Row]
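  // Register a cleaning action so the buffered batches and the response stream are released even
  // if close() is never invoked explicitly.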
private val cleanable =
SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses))
  /**
   * Replace an UnboundRowEncoder with an encoder inferred from the schema, and recursively
   * update the fields of a tuple ProductEncoder if one is found.
   */
private def createEncoder[E](
enc: AgnosticEncoder[E],
dataType: DataType): AgnosticEncoder[E] = {
enc match {
case UnboundRowEncoder =>
// Replace the row encoder with the encoder inferred from the schema.
RowEncoder
.encoderFor(dataType.asInstanceOf[StructType])
.asInstanceOf[AgnosticEncoder[E]]
case ProductEncoder(clsTag, fields, outer) if ProductEncoder.isTuple(clsTag) =>
// Recursively continue updating the tuple product encoder
val schema = dataType.asInstanceOf[StructType]
assert(fields.length <= schema.fields.length)
val updatedFields = fields.zipWithIndex.map { case (f, id) =>
f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
}
ProductEncoder(clsTag, updatedFields, outer)
case _ =>
enc
}
}
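  /**
   * Consumes responses from the server until one of the requested stop conditions is met (or the
   * stream is exhausted), buffering any Arrow batches into `resultMap` along the way. Returns
   * true if at least one non-empty batch was received during this call.
   */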
private def processResponses(
stopOnOperationId: Boolean = false,
stopOnSchema: Boolean = false,
stopOnArrowSchema: Boolean = false,
stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
var nonEmpty = false
var stop = false
while (!stop && responses.hasNext) {
val response = responses.next()
// Collect metrics for this response
observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)
// Save and validate operationId
if (opId == null) {
opId = response.getOperationId
}
if (opId != response.getOperationId) {
        // Backwards compatibility: responses from an old server without the operationId field
        // all report getOperationId == "", so they stay consistent and do not trip this check.
throw new IllegalStateException(
"Received response with wrong operationId. " +
s"Expected '$opId' but received '${response.getOperationId}'.")
}
stop |= stopOnOperationId
// Update the execution status. This information can now be accessed directly from
// the iterator.
if (response.hasExecutionProgress) {
progress = Progress(response.getExecutionProgress)
}
if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType]
stop |= stopOnSchema
}
if (response.hasArrowBatch) {
val ipcStreamBytes = response.getArrowBatch.getData
val expectedNumRows = response.getArrowBatch.getRowCount
val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
if (arrowSchema == null) {
arrowSchema = reader.schema
stop |= stopOnArrowSchema
} else if (arrowSchema != reader.schema) {
throw new IllegalStateException(
s"""Schema Mismatch between expected and received schema:
|=== Expected Schema ===
|$arrowSchema
|=== Received Schema ===
|${reader.schema}
|""".stripMargin)
}
if (structType == null) {
          // If the schema is not available yet, fall back to the arrow schema.
structType = ArrowUtils.fromArrowSchema(reader.schema)
}
if (response.getArrowBatch.hasStartOffset) {
val expectedStartOffset = response.getArrowBatch.getStartOffset
if (numRecords != expectedStartOffset) {
throw new IllegalStateException(
s"Expected arrow batch to start at row offset $numRecords in results, " +
s"but received arrow batch starting at offset $expectedStartOffset.")
}
}
var numRecordsInBatch = 0
val messages = Seq.newBuilder[ArrowMessage]
while (reader.hasNext) {
val message = reader.next()
message match {
case batch: ArrowRecordBatch =>
numRecordsInBatch += batch.getLength
case _ =>
}
messages += message
}
if (numRecordsInBatch != expectedNumRows) {
throw new IllegalStateException(
s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
}
// Skip the entire result if it is empty.
if (numRecordsInBatch > 0) {
numRecords += numRecordsInBatch
resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
nextResultIndex += 1
nonEmpty |= true
stop |= stopOnFirstNonEmptyResponse
}
}
}
nonEmpty
}
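  /**
   * Converts the observed metrics of a response into named Rows, forwarding them to the
   * registered Observation callback (if any) so that blocked observers are released.
   */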
private def processObservedMetrics(
metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
metrics.asScala.map { metric =>
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val keys = mutable.ListBuffer.empty[String]
val values = mutable.ListBuffer.empty[Any]
      (0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
keys += key
values += value
}
      // If the metric was registered by an Observation object, attach the values and unblock
      // any waiting thread.
setObservationMetricsOpt.foreach { setObservationMetrics =>
setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
}
metric.getName -> new GenericRowWithSchema(values.toArray, schema)
}
}
/**
* Returns the number of elements in the result.
*/
def length: Int = {
// We need to process all responses to make sure numRecords is correct.
processResponses()
numRecords
}
/**
* @return
* the schema of the result.
*/
def schema: StructType = {
if (structType == null) {
processResponses(stopOnSchema = true)
}
structType
}
/**
* @return
* the operationId of the result.
*/
def operationId: String = {
if (opId == null) {
processResponses(stopOnOperationId = true)
}
opId
}
/**
* Create an Array with the contents of the result.
*/
def toArray: Array[T] = {
val result = encoder.clsTag.newArray(length)
val rows = iterator
try {
var i = 0
while (rows.hasNext) {
result(i) = rows.next()
assert(i < numRecords)
i += 1
}
} finally {
rows.close()
}
result
}
/**
* Returns all observed metrics in the result.
*/
def getObservedMetrics: Map[String, Row] = {
// We need to process all responses to get all metrics.
processResponses()
observedMetrics.toMap
}
/**
* Returns an iterator over the contents of the result.
*/
def iterator: CloseableIterator[T] =
buildIterator(destructive = false)
  /**
   * Returns a destructive iterator over the contents of the result. Buffered Arrow batches are
   * removed from the result as they are consumed, so the result can only be iterated once.
   */
def destructiveIterator: CloseableIterator[T] =
buildIterator(destructive = true)
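  // Builds a lazily-initialized iterator: responses are not pulled and Arrow data is not
  // deserialized until the first hasNext/next call.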
private def buildIterator(destructive: Boolean): CloseableIterator[T] = {
new CloseableIterator[T] {
private[this] var iter: CloseableIterator[T] = _
private def initialize(): Unit = {
if (iter == null) {
iter = new ArrowDeserializingIterator(
createEncoder(encoder, schema),
new ConcatenatingArrowStreamReader(
allocator,
Iterator.single(new ResultMessageIterator(destructive)),
destructive),
timeZoneId)
}
}
override def hasNext: Boolean = {
initialize()
iter.hasNext
}
override def next(): T = {
initialize()
iter.next()
}
override def close(): Unit = {
if (iter != null) {
iter.close()
}
}
}
}
/**
* Close this result, freeing any underlying resources.
*/
override def close(): Unit = cleanable.clean()
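  // Feeds buffered Arrow messages to the stream reader in arrival order, pulling more responses
  // on demand. In destructive mode, entries are removed from resultMap as they are handed out.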
private class ResultMessageIterator(destructive: Boolean) extends AbstractMessageIterator {
private[this] var totalBytesRead = 0L
private[this] var nextResultIndex = 0
private[this] var current: Iterator[ArrowMessage] = Iterator.empty
override def bytesRead: Long = totalBytesRead
override def schema: pojo.Schema = {
if (arrowSchema == null) {
// We need a schema to proceed. Spark Connect will always
// return a result (with a schema) even if the result is empty.
processResponses(stopOnArrowSchema = true)
Objects.requireNonNull(arrowSchema)
}
arrowSchema
}
override def hasNext: Boolean = {
if (current.hasNext) {
return true
}
val hasNextResult = if (!resultMap.contains(nextResultIndex)) {
self.processResponses(stopOnFirstNonEmptyResponse = true)
} else {
true
}
if (hasNextResult) {
val Some((sizeInBytes, messages)) = if (destructive) {
resultMap.remove(nextResultIndex)
} else {
resultMap.get(nextResultIndex)
}
totalBytesRead += sizeInBytes
current = messages.iterator
nextResultIndex += 1
}
hasNextResult
}
override def next(): ArrowMessage = {
if (!hasNext) {
throw new NoSuchElementException()
}
current.next()
}
}
}
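
// Shared Cleaner used to release a SparkResult's buffered batches and response stream if it
// becomes unreachable without having been closed.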
private object SparkResult {
private val cleaner: Cleaner = Cleaner.create()
}
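
// Holds the resources of a SparkResult so they can be released either explicitly via close()
// (AutoCloseable) or by the Cleaner (Runnable) when the owning SparkResult becomes unreachable.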
private[client] class SparkResultCloseable(
resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])],
responses: CloseableIterator[proto.ExecutePlanResponse])
extends AutoCloseable
with Runnable {
override def close(): Unit = {
resultMap.values.foreach(_._2.foreach(_.close()))
responses.close()
}
override def run(): Unit = {
close()
}
}