| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.connect.execution |
| |
| import scala.concurrent.{ExecutionContext, Promise} |
| import scala.jdk.CollectionConverters._ |
| import scala.util.Try |
| import scala.util.control.NonFatal |
| |
| import com.google.protobuf.Message |
| import org.apache.commons.lang3.StringUtils |
| |
| import org.apache.spark.SparkSQLException |
| import org.apache.spark.connect.proto |
| import org.apache.spark.internal.Logging |
| import org.apache.spark.sql.connect.common.ProtoUtils |
| import org.apache.spark.sql.connect.planner.SparkConnectPlanner |
| import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} |
| import org.apache.spark.sql.connect.utils.ErrorUtils |
| import org.apache.spark.util.{ThreadUtils, Utils} |
| |
| /** |
 * This class launches the actual execution in an execution thread. The execution pushes its
 * responses to an `ExecuteResponseObserver` in the `executeHolder`.
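 *
 * A sketch of the intended lifecycle (hypothetical caller; the real call sites live in the
 * Spark Connect service layer):
 * {{{
 *   val runner = new ExecuteThreadRunner(executeHolder)
 *   runner.start()     // returns immediately; work happens on the execution thread
 *   // ... later, e.g. when the client requests an interrupt:
 *   runner.interrupt()
 * }}}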
| */ |
| private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { |
| |
  // Completed when the execution thread finishes, either successfully or with an error.
  private val promise: Promise[Unit] = Promise[Unit]()
| |
  // The newly created thread will inherit all InheritableThreadLocals used by Spark,
  // e.g. SparkContext.localProperties. If a thread pool were ever used instead,
  // forwarding of these thread locals would have to be handled explicitly.
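  // For example (hypothetical): local properties set on the creating thread before this
  // thread is constructed, e.g. via session.sparkContext.setLocalProperty(...), are
  // visible from the execution thread, because localProperties is an InheritableThreadLocal.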
| private val executionThread: ExecutionThread = new ExecutionThread(promise) |
| |
  // Whether the execution thread was started. Guarded by `lock`.
  private var started: Boolean = false

  // Whether the execution was interrupted. Guarded by `lock`.
  private var interrupted: Boolean = false

  // Whether the execution completed its result stream; once set, it can no longer be
  // interrupted. Guarded by `lock`.
  private var completed: Boolean = false

  // Guards the state transitions of `started`, `interrupted` and `completed`.
  private val lock = new Object
| |
  /** Launches the execution in a background thread and returns immediately. */
| private[connect] def start(): Unit = { |
| lock.synchronized { |
| assert(!started) |
| // Do not start if already interrupted. |
| if (!interrupted) { |
| executionThread.start() |
| started = true |
| } |
| } |
| } |
| |
| /** |
 * Registers a callback that is executed after the execution thread completes or is interrupted.
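 *
 * For example (hypothetical):
 * {{{
 *   runner.processOnCompletion {
 *     case scala.util.Success(_) => // the response stream completed
 *     case scala.util.Failure(t) => // the execution thread failed with t
 *   }
 * }}}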
| */ |
| private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = { |
| promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext) |
| } |
| |
| /** |
 * Interrupts the execution thread.
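 *
 * For example (hypothetical): only the first interrupt is delivered.
 * {{{
 *   runner.interrupt() // true  - this call delivered the interrupt
 *   runner.interrupt() // false - already interrupted (or completed)
 * }}}
 *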
| * @return |
| * true if it was not interrupted before, false if it was already interrupted or completed. |
| */ |
| private[connect] def interrupt(): Boolean = { |
| lock.synchronized { |
| if (!started && !interrupted) { |
        // The execution thread hasn't started yet and will never start,
        // so handle the interrupt here directly.
| interrupted = true |
| ErrorUtils.handleError( |
| "execute", |
| executeHolder.responseObserver, |
| executeHolder.sessionHolder.userId, |
| executeHolder.sessionHolder.sessionId, |
| Some(executeHolder.eventsManager), |
| interrupted)(new SparkSQLException("OPERATION_CANCELED", Map.empty)) |
| true |
| } else if (!interrupted && !completed) { |
        // Checking `completed` prevents sending an interrupt onError after onCompleted.
| interrupted = true |
| executionThread.interrupt() |
| true |
| } else { |
| false |
| } |
| } |
| } |
| |
| private def execute(): Unit = { |
    // The outer execute() handles errors.
    // It is separated from executeInternal() to reduce indentation and improve readability.
| try { |
| try { |
| executeInternal() |
| } catch { |
        // Need to catch Throwable instead of NonFatal, because e.g. InterruptedException is fatal.
| case e: Throwable => |
| logDebug(s"Exception in execute: $e") |
| // Always cancel all remaining execution after error. |
| executeHolder.sessionHolder.session.sparkContext.cancelJobsWithTag(executeHolder.jobTag) |
| // Rely on an internal interrupted flag, because Thread.interrupted() could be cleared, |
| // and different exceptions like InterruptedException, ClosedByInterruptException etc. |
| // could be thrown. |
| if (interrupted) { |
| throw new SparkSQLException("OPERATION_CANCELED", Map.empty) |
| } else { |
          // Rethrow the original error.
| throw e |
| } |
| } finally { |
| executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag) |
| SparkConnectService.executionListener.foreach(_.removeJobTag(executeHolder.jobTag)) |
| executeHolder.sparkSessionTags.foreach { tag => |
| executeHolder.sessionHolder.session.sparkContext.removeJobTag( |
| ExecuteSessionTag( |
| executeHolder.sessionHolder.userId, |
| executeHolder.sessionHolder.sessionId, |
| tag)) |
| } |
| } |
    } catch {
      // ErrorUtils.handleError returns a partial function used as the catch handler; it
      // reports the error to the client through the response observer.
      ErrorUtils.handleError(
| "execute", |
| executeHolder.responseObserver, |
| executeHolder.sessionHolder.userId, |
| executeHolder.sessionHolder.sessionId, |
| Some(executeHolder.eventsManager), |
| interrupted) |
| } |
| } |
| |
  // The inner executeInternal is wrapped by execute() for error handling.
  private def executeInternal(): Unit = {
    // Synchronized: check whether the execution was already interrupted while starting.
| lock.synchronized { |
| if (interrupted) { |
| throw new InterruptedException() |
| } |
| } |
| |
| // `withSession` ensures that session-specific artifacts (such as JARs and class files) are |
| // available during processing. |
| executeHolder.sessionHolder.withSession { session => |
| val debugString = requestString(executeHolder.request) |
| |
| // Set tag for query cancellation |
| session.sparkContext.addJobTag(executeHolder.jobTag) |
| // Register the job for progress reports. |
| SparkConnectService.executionListener.foreach(_.registerJobTag(executeHolder.jobTag)) |
      // Also set all user-defined tags as Spark job tags.
| executeHolder.sparkSessionTags.foreach { tag => |
| session.sparkContext.addJobTag( |
| ExecuteSessionTag( |
| executeHolder.sessionHolder.userId, |
| executeHolder.sessionHolder.sessionId, |
| tag)) |
| } |
| session.sparkContext.setJobDescription( |
| s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") |
| session.sparkContext.setInterruptOnCancel(true) |
| |
| // Add debug information to the query execution so that the jobs are traceable. |
| session.sparkContext.setLocalProperty( |
| "callSite.short", |
| s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") |
| session.sparkContext.setLocalProperty( |
| "callSite.long", |
| StringUtils.abbreviate(debugString, 2048)) |
| |
| executeHolder.request.getPlan.getOpTypeCase match { |
| case proto.Plan.OpTypeCase.COMMAND => handleCommand(executeHolder.request) |
| case proto.Plan.OpTypeCase.ROOT => handlePlan(executeHolder.request) |
| case _ => |
| throw new UnsupportedOperationException( |
| s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") |
| } |
| |
      // Collect the metrics observed during execution (e.g. registered via Dataset.observe).
      val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
| executeHolder.observations.map { case (name, observation) => |
| val values = observation.getOrEmpty.map { case (key, value) => |
| (Some(key), value) |
| }.toSeq |
| name -> values |
| }.toMap |
| } |
| val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = { |
| executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator => |
| accumulator.synchronized { |
| val value = accumulator.value.asScala.toSeq |
| if (value.nonEmpty) { |
| accumulator.reset() |
| Some("__python_accumulator__" -> value.map(value => (None, value))) |
| } else { |
| None |
| } |
| } |
| }.toMap |
| } |
| if (observedMetrics.nonEmpty || accumulatedInPython.nonEmpty) { |
| executeHolder.responseObserver.onNext( |
| SparkConnectPlanExecution |
| .createObservedMetricsResponse( |
| executeHolder.sessionHolder.sessionId, |
| executeHolder.sessionHolder.serverSessionId, |
| executeHolder.request.getPlan.getRoot.getCommon.getPlanId, |
| observedMetrics ++ accumulatedInPython)) |
| } |
| |
| lock.synchronized { |
        // Synchronized before sending ResultComplete and until the result stream is
        // completed, to prevent the following race: a client of a reattachable execution
        // receives ResultComplete, proceeds to send ReleaseExecute, and that triggers an
        // interrupt before the stream finishes.
| |
| if (interrupted) { |
          // Check whether the execution got interrupted at the very last moment.
| throw new InterruptedException() |
| } |
| completed = true // no longer interruptible |
| |
| if (executeHolder.reattachable) { |
| // Reattachable execution sends a ResultComplete at the end of the stream |
| // to signal that there isn't more coming. |
| executeHolder.responseObserver.onNextComplete(createResultComplete()) |
| } else { |
| executeHolder.responseObserver.onCompleted() |
| } |
| } |
| } |
| } |
| |
| private def handlePlan(request: proto.ExecutePlanRequest): Unit = { |
| val responseObserver = executeHolder.responseObserver |
| |
| val execution = new SparkConnectPlanExecution(executeHolder) |
| execution.handlePlan(responseObserver) |
| } |
| |
| private def handleCommand(request: proto.ExecutePlanRequest): Unit = { |
| val responseObserver = executeHolder.responseObserver |
| |
| val command = request.getPlan.getCommand |
| val planner = new SparkConnectPlanner(executeHolder) |
| planner.process(command = command, responseObserver = responseObserver) |
| } |
| |
  private def requestString(request: Message): String = {
| try { |
| Utils.redact( |
| executeHolder.sessionHolder.session.sessionState.conf.stringRedactionPattern, |
| ProtoUtils.abbreviate(request).toString) |
| } catch { |
| case NonFatal(e) => |
| logWarning("Fail to extract debug information", e) |
| "UNKNOWN" |
| } |
| } |
| |
  private def createResultComplete(): proto.ExecutePlanResponse = {
    // Signals the end of the response stream: no further responses will be sent.
| proto.ExecutePlanResponse |
| .newBuilder() |
| .setResultComplete(proto.ExecutePlanResponse.ResultComplete.newBuilder().build()) |
| .build() |
| } |
| |
| private class ExecutionThread(onCompletionPromise: Promise[Unit]) |
| extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { |
| override def run(): Unit = { |
| try { |
| execute() |
| onCompletionPromise.success(()) |
      } catch {
        case NonFatal(e) =>
          // Fatal errors propagate out of run() and terminate the thread without
          // completing the promise.
          onCompletionPromise.failure(e)
| } |
| } |
| } |
| } |
| |
| private[connect] object ExecuteThreadRunner { |
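  // A single daemon thread runs all completion callbacks, keeping them off the execution
  // threads themselves and serializing callbacks across executions.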
  private val namedExecutionContext: ExecutionContext = ExecutionContext
    .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback"))
| } |