/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.util.UUID
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
import scala.util.Try
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkContext, SparkEnv, SparkUnsupportedOperationException}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Base trait for a versioned key-value store which provides read operations. Each instance of a
* `ReadStateStore` represents a specific version of state data, and such instances are created
* through a [[StateStoreProvider]].
*
* The `abort` method will be called when the task completes; implementations should clean up
* their resources in that method.
*
* IMPLEMENTATION NOTES:
* * The implementation may throw an exception from the prefixScan method if the functionality
* is not supported by the implementation. Note that some stateful operations will not work
* when the prefixScan functionality is disabled.
*/
trait ReadStateStore {
/** Unique identifier of the store */
def id: StateStoreId
/** Version of the data in this store before committing updates. */
def version: Long
/**
* Get the current value of a non-null key.
* @return a non-null row if the key exists in the store, otherwise null.
*/
def get(
key: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow
/**
* Provides an iterator containing all values of a non-null key. Implementations must return
* an empty iterator if the key does not exist.
*
* It is expected to throw an exception if Spark calls this method on a column family that was
* not created with `useMultipleValuesPerKey = true`.
*/
def valuesIterator(key: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow]
/**
* Return an iterator containing all the key-value pairs that match the given prefix key.
*
* The operator will provide a [[PrefixKeyScanStateEncoderSpec]] with `numColsPrefixKey` greater
* than 0 to the StateStoreProvider.init method if it needs to leverage the "prefix scan"
* feature. The schema of the prefix key must be the same as the leftmost `numColsPrefixKey`
* columns of the key schema.
*
* It is expected to throw an exception if Spark calls this method without setting
* `numColsPrefixKey` to a value greater than 0.
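*
* A minimal sketch, assuming the store was initialized with a [[PrefixKeyScanStateEncoderSpec]]
* and `prefixKey` is an UnsafeRow containing only the leftmost prefix columns:
* {{{
*   store.prefixScan(prefixKey).foreach { pair =>
*     // every key-value pair whose key starts with prefixKey
*   }
* }}}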
*/
def prefixScan(
prefixKey: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair]
/** Return an iterator containing all the key-value pairs in the StateStore. */
def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair]
/**
* Clean up the resources allocated by this store.
*
* The method name `abort` is kept for backward compatibility with [[StateStore]].
*/
def abort(): Unit
}
/**
* Base trait for a versioned key-value store which provides both read and write operations. Each
* instance of a `StateStore` represents a specific version of state data, and such instances are
* created through a [[StateStoreProvider]].
*
* Unlike [[ReadStateStore]], the `abort` method may not be called if the `commit` method
* succeeds in committing the change (`hasCommitted` returns `true`). Otherwise, the `abort`
* method will be called. Implementations should handle resource cleanup in both methods, and
* also guard against double cleanup.
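*
* A minimal update sketch (the provider, version, key, and value below are assumptions):
* {{{
*   val store: StateStore = provider.getStore(version)
*   try {
*     store.put(key, value)
*     val newVersion: Long = store.commit()   // hasCommitted returns true after this
*   } catch {
*     case NonFatal(e) =>
*       store.abort()                         // roll back uncommitted changes
*       throw e
*   }
* }}}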
*/
trait StateStore extends ReadStateStore {
/**
* Remove column family with given name, if present.
*/
def removeColFamilyIfExists(colFamilyName: String): Boolean
/**
* Create column family with given name, if absent.
*/
def createColFamilyIfAbsent(
colFamilyName: String,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
useMultipleValuesPerKey: Boolean = false,
isInternal: Boolean = false): Unit
/**
* Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows
* in the params can be reused, and must make copies of the data as needed for persistence.
*/
def put(
key: UnsafeRow,
value: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
/**
* Remove a single non-null key.
*/
def remove(
key: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
/**
* Merges the provided value with the existing values of a non-null key. If no value exists
* for the key, this operation behaves like [[StateStore.put()]].
*
* It is expected to throw an exception if Spark calls this method on a column family that was
* not created with `useMultipleValuesPerKey = true`.
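*
* A minimal sketch, assuming the (hypothetical) column family "eventList" was created with
* `useMultipleValuesPerKey = true`:
* {{{
*   store.merge(key, value1, "eventList")
*   store.merge(key, value2, "eventList")
*   val values: Iterator[UnsafeRow] = store.valuesIterator(key, "eventList")
* }}}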
*/
def merge(key: UnsafeRow, value: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
/**
* Commit all the updates that have been made to the store, and return the new version.
* Implementations should ensure that no more updates (puts, removes) can be made after a
* commit in order to avoid incorrect usage.
*/
def commit(): Long
/**
* Abort all the updates that have been made to the store. Implementations should ensure that
* no more updates (puts, removes) can be made after an abort in order to avoid incorrect usage.
*/
override def abort(): Unit
/**
* Return an iterator containing all the key-value pairs in the StateStore. Implementations must
* ensure that updates (puts, removes) can be made while iterating over this iterator.
*
* It is not required for implementations to ensure the iterator reflects all updates being
* performed after initialization of the iterator. Callers should perform all updates before
* calling this method if all updates should be visible in the returned iterator.
*/
override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
Iterator[UnsafeRowPair]
/** Current metrics of the state store */
def metrics: StateStoreMetrics
/**
* Whether all updates have been committed
*/
def hasCommitted: Boolean
}
/** Wraps the instance of StateStore to make the instance read-only. */
class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
override def id: StateStoreId = store.id
override def version: Long = store.version
override def get(key: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = store.get(key,
colFamilyName)
override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
Iterator[UnsafeRowPair] = store.iterator(colFamilyName)
override def abort(): Unit = store.abort()
override def prefixScan(prefixKey: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] =
store.prefixScan(prefixKey, colFamilyName)
override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
store.valuesIterator(key, colFamilyName)
}
}
/**
* Metrics reported by a state store
* @param numKeys Number of keys in the state store
* @param memoryUsedBytes Memory used by the state store
* @param customMetrics Custom implementation-specific metrics
* The metrics reported through this must have the same `name` as those
* reported by `StateStoreProvider.customMetrics`.
*/
case class StateStoreMetrics(
numKeys: Long,
memoryUsedBytes: Long,
customMetrics: Map[StateStoreCustomMetric, Long])
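/**
* Helpers for [[StateStoreMetrics]]. `combine` sums the key counts, memory usage, and each
* custom metric (matched by metric equality) across all the given per-store metrics.
*
* A worked sketch, with `filesMetric` as an assumed custom metric:
* {{{
*   val m1 = StateStoreMetrics(numKeys = 10, memoryUsedBytes = 100, Map(filesMetric -> 2L))
*   val m2 = StateStoreMetrics(numKeys = 5, memoryUsedBytes = 50, Map(filesMetric -> 3L))
*   StateStoreMetrics.combine(Seq(m1, m2))  // == StateStoreMetrics(15, 150, Map(filesMetric -> 5L))
* }}}
*/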
object StateStoreMetrics {
def combine(allMetrics: Seq[StateStoreMetrics]): StateStoreMetrics = {
val distinctCustomMetrics = allMetrics.flatMap(_.customMetrics.keys).distinct
val customMetrics = allMetrics.flatMap(_.customMetrics)
val combinedCustomMetrics = distinctCustomMetrics.map { customMetric =>
val sameMetrics = customMetrics.filter(_._1 == customMetric)
val sumOfMetrics = sameMetrics.map(_._2).sum
customMetric -> sumOfMetrics
}.toMap
StateStoreMetrics(
allMetrics.map(_.numKeys).sum,
allMetrics.map(_.memoryUsedBytes).sum,
combinedCustomMetrics)
}
}
/**
* Name and description of custom implementation-specific metrics that a
* state store may wish to expose. Also provides a [[SQLMetric]] instance to
* show the metric in the UI and accumulate it at the query level.
*/
trait StateStoreCustomMetric {
def name: String
def desc: String
def withNewDesc(desc: String): StateStoreCustomMetric
def createSQLMetric(sparkContext: SparkContext): SQLMetric
}
case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric {
override def withNewDesc(newDesc: String): StateStoreCustomSumMetric = copy(desc = newDesc)
override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
SQLMetrics.createMetric(sparkContext, desc)
}
case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric {
override def withNewDesc(desc: String): StateStoreCustomSizeMetric = copy(desc = desc)
override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
SQLMetrics.createSizeMetric(sparkContext, desc)
}
case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric {
override def withNewDesc(desc: String): StateStoreCustomTimingMetric = copy(desc = desc)
override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
SQLMetrics.createTimingMetric(sparkContext, desc)
}
/**
* An exception thrown when an invalid UnsafeRow is detected in the state store.
*/
class InvalidUnsafeRowException(error: String)
extends RuntimeException("The streaming query failed by state format invalidation. " +
"The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " +
"incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " +
"among restart. For the first case, you can try to restart the application without " +
s"checkpoint or use the legacy Spark version to process the streaming state.\n$error", null)
sealed trait KeyStateEncoderSpec
case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec
case class PrefixKeyScanStateEncoderSpec(
keySchema: StructType,
numColsPrefixKey: Int) extends KeyStateEncoderSpec {
if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) {
throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString)
}
}
/** Encodes rows so that they can be range-scanned based on orderingOrdinals */
case class RangeKeyScanStateEncoderSpec(
keySchema: StructType,
orderingOrdinals: Seq[Int]) extends KeyStateEncoderSpec {
if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) {
throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString)
}
}
/**
* Trait representing a provider that provides [[StateStore]] instances representing
* versions of state data.
*
* The life cycle of a provider and the stores it provides is as follows.
*
* - A StateStoreProvider is created in an executor for each unique [[StateStoreId]] when
* the first batch of a streaming query is executed on the executor. All subsequent batches reuse
* this provider instance until the query is stopped.
*
* - Every batch of streaming data requests a specific version of the state data by invoking
* `getStore(version)`, which returns an instance of [[StateStore]] through which the required
* version of the data can be accessed. It is the responsibility of the provider to populate
* this store with context information like the schema of keys and values, etc.
*
* - After the streaming query is stopped, the created provider instances are lazily disposed of.
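*
* A minimal sketch of the executor-side flow (the arguments below are assumptions):
* {{{
*   val provider = StateStoreProvider.createAndInit(
*     providerId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
*     useColumnFamilies = false, storeConf = storeConf, hadoopConf = hadoopConf,
*     useMultipleValuesPerKey = false)
*   val store = provider.getStore(version)   // versioned read-write access
*   // ... apply updates and commit ...
*   provider.doMaintenance()                 // optional background maintenance
*   provider.close()                         // when the provider is unloaded
* }}}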
*/
trait StateStoreProvider {
/**
* Initialize the provider with more contextual information from the SQL operator.
* This method will be called first after creating an instance of the StateStoreProvider by
* reflection.
*
* @param stateStoreId Id of the versioned StateStores that this provider will generate
* @param keySchema Schema of keys to be stored
* @param valueSchema Schema of value to be stored
* @param keyStateEncoderSpec Specification of how the key rows are encoded, which determines
*                            whether the store needs to support prefix scan or range scan.
*                            The operator should not call the prefixScan method in StateStore
*                            unless a [[PrefixKeyScanStateEncoderSpec]] is provided here.
* @param useColumnFamilies Whether the underlying state store uses a single or multiple column
* families
* @param storeConfs Configurations used by the StateStores
* @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
* @param useMultipleValuesPerKey Whether the underlying state store needs to support multiple
* values for a single key.
*/
def init(
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean,
storeConfs: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false): Unit
/**
* Return the id of the StateStores this provider will generate.
* Should be the same as the one passed in init().
*/
def stateStoreId: StateStoreId
/** Called when the provider instance is unloaded from the executor */
def close(): Unit
/** Return an instance of [[StateStore]] representing state data of the given version */
def getStore(version: Long): StateStore
/**
* Return an instance of [[ReadStateStore]] representing state data of the given version.
* By default it will return the same instance as getStore(version) but wrapped to prevent
* modification. Providers can override this and return an optimized version of
* [[ReadStateStore]], based on the fact that the instance will only be used for reading.
*/
def getReadStore(version: Long): ReadStateStore =
new WrappedReadStateStore(getStore(version))
/** Optional method for providers to allow for background maintenance (e.g. compactions) */
def doMaintenance(): Unit = { }
/**
* Optional custom metrics that the implementation may want to report.
* @note The StateStore objects created by this provider must report the same custom metrics
* (specifically, same names) through `StateStore.metrics`.
*/
def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil
}
object StateStoreProvider {
/**
* Return an instance of the given provider class name. The instance will not be initialized.
*/
def create(providerClassName: String): StateStoreProvider = {
val providerClass = Utils.classForName(providerClassName)
providerClass.getConstructor().newInstance().asInstanceOf[StateStoreProvider]
}
/**
* Return an instance of the required provider, initialized with the given configurations.
*/
def createAndInit(
providerId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean,
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean): StateStoreProvider = {
val provider = create(storeConf.providerClass)
provider.init(providerId.storeId, keySchema, valueSchema, keyStateEncoderSpec,
useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
provider
}
/**
* Use the expected schema to check whether the UnsafeRow is valid.
*/
def validateStateRowFormat(
keyRow: UnsafeRow,
keySchema: StructType,
valueRow: UnsafeRow,
valueSchema: StructType,
conf: StateStoreConf): Unit = {
if (conf.formatValidationEnabled) {
val validationError = UnsafeRowUtils.validateStructuralIntegrityWithReason(keyRow, keySchema)
validationError.foreach { error => throw new InvalidUnsafeRowException(error) }
if (conf.formatValidationCheckValue) {
val validationError =
UnsafeRowUtils.validateStructuralIntegrityWithReason(valueRow, valueSchema)
validationError.foreach { error => throw new InvalidUnsafeRowException(error) }
}
}
}
}
/**
* Unique identifier for a provider, used to identify when providers can be reused.
* Note that `queryRunId` is used to uniquely identify a provider, so that the same provider
* instance is not reused across query restarts.
*/
case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)
object StateStoreProviderId {
private[sql] def apply(
stateInfo: StatefulOperatorStateInfo,
partitionIndex: Int,
storeName: String): StateStoreProviderId = {
val storeId = StateStoreId(
stateInfo.checkpointLocation, stateInfo.operatorId, partitionIndex, storeName)
StateStoreProviderId(storeId, stateInfo.queryRunId)
}
}
/**
* Unique identifier for a bunch of keyed state data.
* @param checkpointRootLocation Root directory where all the state data of a query is stored
* @param operatorId Unique id of a stateful operator
* @param partitionId Index of the partition of an operator's state data
* @param storeName Optional, name of the store. Each partition can optionally use multiple state
* stores, but they have to be identified by distinct names.
*/
case class StateStoreId(
checkpointRootLocation: String,
operatorId: Long,
partitionId: Int,
storeName: String = StateStoreId.DEFAULT_STORE_NAME) {
/**
* Checkpoint directory to be used by a single state store, identified uniquely by the tuple
* (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should
* use this path for saving state data, as this ensures that distinct stores will write to
* different locations.
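*
* For example, `StateStoreId("/chk/state", 0, 5).storeCheckpointLocation()` resolves to
* `/chk/state/0/5`; a non-default store name would be appended as an extra subdirectory,
* e.g. `/chk/state/0/5/myStore`.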
*/
def storeCheckpointLocation(): Path = {
if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
// For reading state store data that was generated before store names were used (Spark <= 2.2)
new Path(checkpointRootLocation, s"$operatorId/$partitionId")
} else {
new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
}
}
override def toString: String = {
s"""StateStoreId[ checkpointRootLocation=$checkpointRootLocation, operatorId=$operatorId,
| partitionId=$partitionId, storeName=$storeName ]
|""".stripMargin.replaceAll("\n", "")
}
}
object StateStoreId {
val DEFAULT_STORE_NAME = "default"
}
/** Mutable and reusable class for representing a pair of UnsafeRows. */
class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = {
this.key = key
this.value = value
this
}
}
/**
* Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
* by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
* it also runs a periodic background task to do maintenance on the loaded stores. For each
* store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance
* of the store is the active instance. Accordingly, it either keeps the store loaded and
* performs maintenance, or unloads it.
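*
* A minimal sketch of how an executor task obtains a store (the arguments are assumptions):
* {{{
*   val store = StateStore.get(
*     storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
*     version, useColumnFamilies = false, storeConf = storeConf, hadoopConf = hadoopConf)
*   // read and update the store, then commit() on success or abort() on failure
* }}}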
*/
object StateStore extends Logging {
val PARTITION_ID_TO_CHECK_SCHEMA = 0
val DEFAULT_COL_FAMILY_NAME = "default"
@GuardedBy("loadedProviders")
private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]()
@GuardedBy("loadedProviders")
private val schemaValidated = new mutable.HashMap[StateStoreProviderId, Option[Throwable]]()
private val maintenanceThreadPoolLock = new Object
// Shared reference to an exception thrown by a maintenance thread; the scheduling thread
// checks it to see whether an exception has been thrown in a maintenance task.
private val threadPoolException = new AtomicReference[Throwable](null)
// This set is to keep track of the partitions that are queued
// for maintenance or currently have maintenance running on them
// to prevent the same partition from being processed concurrently.
@GuardedBy("maintenanceThreadPoolLock")
private val maintenancePartitions = new mutable.HashSet[StateStoreProviderId]
/**
* Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
* will be called when an exception happens.
*/
class MaintenanceTask(periodMs: Long, task: => Unit, onError: => Unit) {
private val executor =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
private val runnable = new Runnable {
override def run(): Unit = {
try {
task
} catch {
case NonFatal(e) =>
logWarning("Error running maintenance thread", e)
onError
throw e
}
}
}
private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate(
runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
def stop(): Unit = {
future.cancel(false)
executor.shutdown()
}
def isRunning: Boolean = !future.isDone
}
/**
* Thread pool that runs maintenance on the partitions that are periodically
* scheduled by MaintenanceTask
*/
class MaintenanceThreadPool(numThreads: Int) {
private val threadPool = ThreadUtils.newDaemonFixedThreadPool(
numThreads, "state-store-maintenance-thread")
def execute(runnable: Runnable): Unit = {
threadPool.execute(runnable)
}
def stop(): Unit = {
threadPool.shutdown()
}
}
@GuardedBy("loadedProviders")
private var maintenanceTask: MaintenanceTask = null
@GuardedBy("loadedProviders")
private var maintenanceThreadPool: MaintenanceThreadPool = null
@GuardedBy("loadedProviders")
private var _coordRef: StateStoreCoordinatorRef = null
/** Get or create a read-only store associated with the id. */
def getReadOnly(
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
version: Long,
useColumnFamilies: Boolean,
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false): ReadStateStore = {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
storeProvider.getReadStore(version)
}
/** Get or create a store associated with the id. */
def get(
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
version: Long,
useColumnFamilies: Boolean,
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false): StateStore = {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
storeProvider.getStore(version)
}
private def disallowBinaryInequalityColumn(schema: StructType): Unit = {
if (!UnsafeRowUtils.isBinaryStable(schema)) {
throw new SparkUnsupportedOperationException(
errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY",
messageParameters = Map("schema" -> schema.json)
)
}
}
private def getStateStoreProvider(
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean,
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean): StateStoreProvider = {
loadedProviders.synchronized {
startMaintenanceIfNeeded(storeConf)
if (storeProviderId.storeId.partitionId == PARTITION_ID_TO_CHECK_SCHEMA) {
val result = schemaValidated.getOrElseUpdate(storeProviderId, {
// SPARK-47776: collation introduces the concept of binary (in)equality, which means
// that with some collations we can no longer just compare the binary format of two
// UnsafeRows to determine equality. For example, 'aaa' and 'AAA' can be "semantically"
// equal under a case-insensitive collation.
// The state store is basically key-value storage, and most provider implementations
// rely on the fact that all the columns in the key schema support binary equality.
// We need to disallow binary-inequality columns in the key schema until this is
// supported by the majority of state store providers (or at a higher level of the
// state store).
disallowBinaryInequalityColumn(keySchema)
val checker = new StateSchemaCompatibilityChecker(storeProviderId, hadoopConf)
// Regardless of configuration, we check compatibility to at least write the schema file
// if necessary.
// If format validation for the value schema is disabled, we also disable the schema
// compatibility check for the value schema.
val ret = Try(
checker.check(keySchema, valueSchema,
ignoreValueSchema = !storeConf.formatValidationCheckValue)
).toEither.fold(Some(_), _ => None)
if (storeConf.stateSchemaCheckEnabled) {
ret
} else {
None
}
})
if (result.isDefined) {
throw result.get
}
}
// SPARK-42567 - Track the load time for the state store provider and log a warning if it
// takes longer than 2s.
val (provider, loadTimeMs) = Utils.timeTakenMs {
loadedProviders.getOrElseUpdate(
storeProviderId,
StateStoreProvider.createAndInit(
storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
)
}
if (loadTimeMs > 2000L) {
logWarning(log"Loaded state store provider in loadTimeMs=${MDC(LOAD_TIME, loadTimeMs)} " +
log"for storeId=${MDC(STORE_ID, storeProviderId.storeId.toString)} and " +
log"queryRunId=${MDC(QUERY_RUN_ID, storeProviderId.queryRunId)}")
}
val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
providerIdsToUnload.foreach(unload(_))
provider
}
}
/** Unload a state store provider */
def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeProviderId).foreach(_.close())
}
/** Unload all state store providers: for unit testing purposes only */
private[sql] def unloadAll(): Unit = loadedProviders.synchronized {
loadedProviders.keySet.foreach { key => unload(key) }
loadedProviders.clear()
}
/** Whether a state store provider is loaded or not */
def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized {
loadedProviders.contains(storeProviderId)
}
/** Check if maintenance thread is running and scheduled future is not done */
def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
maintenanceTask != null && maintenanceTask.isRunning
}
/** Stop maintenance thread and reset the maintenance task */
def stopMaintenanceTask(): Unit = loadedProviders.synchronized {
if (maintenanceThreadPool != null) {
threadPoolException.set(null)
maintenanceThreadPoolLock.synchronized {
maintenancePartitions.clear()
}
maintenanceThreadPool.stop()
maintenanceThreadPool = null
}
if (maintenanceTask != null) {
maintenanceTask.stop()
maintenanceTask = null
}
}
/** Unload and stop all state store providers */
def stop(): Unit = loadedProviders.synchronized {
loadedProviders.keySet.foreach { key => unload(key) }
loadedProviders.clear()
_coordRef = null
stopMaintenanceTask()
logInfo("StateStore stopped")
}
/** Start the periodic maintenance task if not already started and if Spark is active */
private def startMaintenanceIfNeeded(storeConf: StateStoreConf): Unit = {
val numMaintenanceThreads = storeConf.numStateStoreMaintenanceThreads
loadedProviders.synchronized {
if (SparkEnv.get != null && !isMaintenanceRunning) {
maintenanceTask = new MaintenanceTask(
storeConf.maintenanceInterval,
task = { doMaintenance() },
onError = { loadedProviders.synchronized {
logInfo("Stopping maintenance task since an error was encountered.")
stopMaintenanceTask()
// SPARK-44504 - Unload explicitly to force closing underlying DB instance
// and releasing allocated resources, especially for RocksDBStateStoreProvider.
loadedProviders.keySet.foreach { key => unload(key) }
loadedProviders.clear()
}
}
)
maintenanceThreadPool = new MaintenanceThreadPool(numMaintenanceThreads)
logInfo("State Store maintenance task started")
}
}
}
private def processThisPartition(id: StateStoreProviderId): Boolean = {
maintenanceThreadPoolLock.synchronized {
if (!maintenancePartitions.contains(id)) {
maintenancePartitions.add(id)
true
} else {
false
}
}
}
/**
* Execute background maintenance task in all the loaded store providers if they are still
* the active instances according to the coordinator.
*/
private def doMaintenance(): Unit = {
logDebug("Doing maintenance")
if (SparkEnv.get == null) {
throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
}
loadedProviders.synchronized {
loadedProviders.toSeq
}.foreach { case (id, provider) =>
// Check whether an earlier maintenance task has failed with an exception; if so, propagate it
if (threadPoolException.get() != null) {
val exception = threadPoolException.get()
logWarning("Error in maintenanceThreadPool", exception)
throw exception
}
if (processThisPartition(id)) {
maintenanceThreadPool.execute(() => {
val startTime = System.currentTimeMillis()
try {
provider.doMaintenance()
if (!verifyIfStoreInstanceActive(id)) {
unload(id)
logInfo(log"Unloaded ${MDC(STATE_STORE_PROVIDER, provider)}")
}
} catch {
case NonFatal(e) =>
logWarning(log"Error managing ${MDC(STATE_STORE_PROVIDER, provider)}, " +
log"stopping management thread", e)
threadPoolException.set(e)
} finally {
val duration = System.currentTimeMillis() - startTime
val logMsg =
log"Finished maintenance task for provider=${MDC(STATE_STORE_PROVIDER, id)}" +
log" in elapsed_time=${MDC(TIME_UNITS, duration)}\n"
if (duration > 5000) {
logInfo(logMsg)
} else {
logDebug(logMsg)
}
maintenanceThreadPoolLock.synchronized {
maintenancePartitions.remove(id)
}
}
})
} else {
logInfo(log"Not processing partition ${MDC(PARTITION_ID, id)} " +
log"for maintenance because it is currently " +
log"being processed")
}
}
}
private def reportActiveStoreInstance(
storeProviderId: StateStoreProviderId,
otherProviderIds: Seq[StateStoreProviderId]): Seq[StateStoreProviderId] = {
if (SparkEnv.get != null) {
val host = SparkEnv.get.blockManager.blockManagerId.host
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
val providerIdsToUnload = coordinatorRef
.map(_.reportActiveInstance(storeProviderId, host, executorId, otherProviderIds))
.getOrElse(Seq.empty[StateStoreProviderId])
logInfo(log"Reported that the loaded instance " +
log"${MDC(STATE_STORE_PROVIDER, storeProviderId)} is active")
logDebug(log"The loaded instances are going to unload: " +
log"${MDC(STATE_STORE_PROVIDER, providerIdsToUnload.mkString(", "))}")
providerIdsToUnload
} else {
Seq.empty[StateStoreProviderId]
}
}
private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = {
if (SparkEnv.get != null) {
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
val verified =
coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false)
logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified")
verified
} else {
false
}
}
private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
val env = SparkEnv.get
if (env != null) {
val isDriver =
env.executorId == SparkContext.DRIVER_IDENTIFIER
// If running locally, then the coordinator reference in _coordRef may have become inactive
// as SparkContext + SparkEnv may have been restarted. Hence, when running in the driver,
// always recreate the reference.
if (isDriver || _coordRef == null) {
logDebug("Getting StateStoreCoordinatorRef")
_coordRef = StateStoreCoordinatorRef.forExecutor(env)
}
logInfo(log"Retrieved reference to StateStoreCoordinator: " +
log"${MDC(STATE_STORE_PROVIDER, _coordRef)}")
Some(_coordRef)
} else {
_coordRef = null
None
}
}
}