#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
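
# Verify that the optional Spark Connect client dependencies (e.g. grpcio, pandas,
# pyarrow) are importable before anything below tries to use them.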
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import json
import sys
import pickle
from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional

from pyspark.errors import StreamingQueryException, PySparkValueError
import pyspark.sql.connect.proto as pb2
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect import proto
from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.streaming import StreamingQueryListener
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
StreamingQueryManager as PySparkStreamingQueryManager,
)
from pyspark.errors.exceptions.connect import (
StreamingQueryException as CapturedStreamingQueryException,
)
from pyspark.errors import PySparkPicklingError

if TYPE_CHECKING:
    from pyspark.sql.connect.session import SparkSession

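
# Client-side handle to a streaming query running on a Spark Connect server.
# Each method below builds a StreamingQueryCommand protobuf message and sends it
# through the session's Connect client; no query state is cached locally beyond
# the (id, runId, name) captured at construction time.
#
# Rough usage sketch (assuming a Spark Connect session `spark`):
#   query = spark.readStream.format("rate").load().writeStream.format("console").start()
#   query.isActive        # round-trips a status command to the server
#   query.stop()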
class StreamingQuery:
def __init__(
self, session: "SparkSession", queryId: str, runId: str, name: Optional[str] = None
) -> None:
self._session = session
self._query_id = queryId
self._run_id = runId
self._name = name

    @property
    def id(self) -> str:
        return self._query_id

    id.__doc__ = PySparkStreamingQuery.id.__doc__

    @property
    def runId(self) -> str:
        return self._run_id

    runId.__doc__ = PySparkStreamingQuery.runId.__doc__

    @property
    def name(self) -> Optional[str]:
        return self._name

    name.__doc__ = PySparkStreamingQuery.name.__doc__

    @property
    def isActive(self) -> bool:
        return self._fetch_status().is_active

    isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__

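    # With a timeout (seconds, converted to milliseconds for the wire) the server
    # reports whether the query terminated within that window; without a timeout the
    # call blocks until the query terminates and the method returns None.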
def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
cmd = pb2.StreamingQueryCommand()
if timeout is not None:
if not isinstance(timeout, (int, float)) or timeout <= 0:
raise PySparkValueError(
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "timeout", "arg_value": type(timeout).__name__},
)
cmd.await_termination.timeout_ms = int(timeout * 1000)
terminated = self._execute_streaming_query_cmd(cmd).await_termination.terminated
return terminated
else:
await_termination_cmd = pb2.StreamingQueryCommand.AwaitTerminationCommand()
cmd.await_termination.CopyFrom(await_termination_cmd)
self._execute_streaming_query_cmd(cmd)
return None
    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__

@property
def status(self) -> Dict[str, Any]:
proto = self._fetch_status()
return {
"message": proto.status_message,
"isDataAvailable": proto.is_data_available,
"isTriggerActive": proto.is_trigger_active,
}
    status.__doc__ = PySparkStreamingQuery.status.__doc__

@property
def recentProgress(self) -> List[Dict[str, Any]]:
cmd = pb2.StreamingQueryCommand()
cmd.recent_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
return [json.loads(p) for p in progress]
    recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__

@property
def lastProgress(self) -> Optional[Dict[str, Any]]:
cmd = pb2.StreamingQueryCommand()
cmd.last_progress = True
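        # Even for a last-progress request the server replies in the recent_progress
        # field of the result; the newest update is the last entry in that list.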
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
if len(progress) > 0:
return json.loads(progress[-1])
else:
return None
    lastProgress.__doc__ = PySparkStreamingQuery.lastProgress.__doc__

def processAllAvailable(self) -> None:
cmd = pb2.StreamingQueryCommand()
cmd.process_all_available = True
self._execute_streaming_query_cmd(cmd)
    processAllAvailable.__doc__ = PySparkStreamingQuery.processAllAvailable.__doc__

def stop(self) -> None:
cmd = pb2.StreamingQueryCommand()
cmd.stop = True
self._execute_streaming_query_cmd(cmd)
    stop.__doc__ = PySparkStreamingQuery.stop.__doc__

def explain(self, extended: bool = False) -> None:
cmd = pb2.StreamingQueryCommand()
cmd.explain.extended = extended
result = self._execute_streaming_query_cmd(cmd).explain.result
print(result)
    explain.__doc__ = PySparkStreamingQuery.explain.__doc__

def exception(self) -> Optional[StreamingQueryException]:
cmd = pb2.StreamingQueryCommand()
cmd.exception = True
exception = self._execute_streaming_query_cmd(cmd).exception
if not exception.HasField("exception_message"):
return None
else:
# Drop the Java StreamingQueryException type info
# exception_message maps to the return value of original
# StreamingQueryException's toString method
msg = exception.exception_message.split(": ", 1)[1]
if exception.HasField("stack_trace"):
msg += f"\n\nJVM stacktrace:\n{exception.stack_trace}"
return CapturedStreamingQueryException(msg, reason=exception.error_class)
    exception.__doc__ = PySparkStreamingQuery.exception.__doc__

def _fetch_status(self) -> pb2.StreamingQueryCommandResult.StatusResult:
cmd = pb2.StreamingQueryCommand()
cmd.status = True
        return self._execute_streaming_query_cmd(cmd).status

def _execute_streaming_query_cmd(
self, cmd: pb2.StreamingQueryCommand
) -> pb2.StreamingQueryCommandResult:
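        # Tag the command with this query's id and run id, wrap it in a top-level
        # Connect Command, and extract the typed result from the response properties.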
cmd.query_id.id = self._query_id
cmd.query_id.run_id = self._run_id
exec_cmd = pb2.Command()
exec_cmd.streaming_query_command.CopyFrom(cmd)
(_, properties) = self._session.client.execute_command(exec_cmd)
        return cast(pb2.StreamingQueryCommandResult, properties["streaming_query_command_result"])


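# Client-side counterpart of StreamingQueryManager for Spark Connect: each call is
# encoded as a StreamingQueryManagerCommand and executed by the server, which owns
# the real query registry and listener bus.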
class StreamingQueryManager:
def __init__(self, session: "SparkSession") -> None:
        self._session = session

@property
def active(self) -> List[StreamingQuery]:
cmd = pb2.StreamingQueryManagerCommand()
cmd.active = True
queries = self._execute_streaming_query_manager_cmd(cmd).active.active_queries
return [StreamingQuery(self._session, q.id.id, q.id.run_id, q.name) for q in queries]
    active.__doc__ = PySparkStreamingQueryManager.active.__doc__

def get(self, id: str) -> Optional[StreamingQuery]:
cmd = pb2.StreamingQueryManagerCommand()
cmd.get_query = id
response = self._execute_streaming_query_manager_cmd(cmd)
if response.HasField("query"):
query = response.query
return StreamingQuery(self._session, query.id.id, query.id.run_id, query.name)
else:
return None
    get.__doc__ = PySparkStreamingQueryManager.get.__doc__

def awaitAnyTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
cmd = pb2.StreamingQueryManagerCommand()
if timeout is not None:
if not isinstance(timeout, (int, float)) or timeout <= 0:
raise PySparkValueError(
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "timeout", "arg_value": type(timeout).__name__},
)
cmd.await_any_termination.timeout_ms = int(timeout * 1000)
terminated = self._execute_streaming_query_manager_cmd(
cmd
).await_any_termination.terminated
return terminated
else:
await_any_termination_cmd = (
pb2.StreamingQueryManagerCommand.AwaitAnyTerminationCommand()
)
cmd.await_any_termination.CopyFrom(await_any_termination_cmd)
self._execute_streaming_query_manager_cmd(cmd)
return None
    awaitAnyTermination.__doc__ = PySparkStreamingQueryManager.awaitAnyTermination.__doc__

def resetTerminated(self) -> None:
cmd = pb2.StreamingQueryManagerCommand()
cmd.reset_terminated = True
self._execute_streaming_query_manager_cmd(cmd)
    resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__

def addListener(self, listener: StreamingQueryListener) -> None:
listener._init_listener_id()
cmd = pb2.StreamingQueryManagerCommand()
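        # The Python listener is cloudpickled and shipped to the server as a PythonUDF
        # payload together with the client's Python version; the listener id assigned
        # above is what removeListener later uses to deregister it on the server.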
expr = proto.PythonUDF()
try:
expr.command = CloudPickleSerializer().dumps(listener)
except pickle.PicklingError:
raise PySparkPicklingError(
error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
message_parameters={"name": "addListener"},
)
expr.python_ver = get_python_ver()
cmd.add_listener.python_listener_payload.CopyFrom(expr)
cmd.add_listener.id = listener._id
self._execute_streaming_query_manager_cmd(cmd)
    addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__

def removeListener(self, listener: StreamingQueryListener) -> None:
cmd = pb2.StreamingQueryManagerCommand()
cmd.remove_listener.id = listener._id
self._execute_streaming_query_manager_cmd(cmd)
    removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__

def _execute_streaming_query_manager_cmd(
self, cmd: pb2.StreamingQueryManagerCommand
) -> pb2.StreamingQueryManagerCommandResult:
exec_cmd = pb2.Command()
exec_cmd.streaming_query_manager_command.CopyFrom(cmd)
(_, properties) = self._session.client.execute_command(exec_cmd)
return cast(
pb2.StreamingQueryManagerCommandResult,
properties["streaming_query_manager_command_result"],
)


def _test() -> None:
import doctest
import os
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.streaming.query
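
    # Run this module's doctests against a throwaway local Spark Connect session
    # created below via the "local[4]" remote URL.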
os.chdir(os.environ["SPARK_HOME"])
globs = pyspark.sql.connect.streaming.query.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.streaming.query tests")
.remote("local[4]")
.getOrCreate()
)
(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.streaming.query,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
globs["spark"].stop()
if failure_count:
sys.exit(-1)


if __name__ == "__main__":
_test()