| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import unittest |
| import uuid |
| from collections.abc import Generator |
| from typing import Optional, Any, Union |
| |
| from pyspark.testing.connectutils import should_test_connect, connect_requirement_message |
| from pyspark.testing.utils import eventually |
| |
| if should_test_connect: |
| import grpc |
| from google.rpc import status_pb2 |
| import pandas as pd |
| import pyarrow as pa |
| from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder |
| from pyspark.sql.connect.client.retries import ( |
| Retrying, |
| DefaultPolicy, |
| ) |
| from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator |
| from pyspark.sql.connect.session import SparkSession as RemoteSparkSession |
| from pyspark.errors import PySparkRuntimeError |
| import pyspark.sql.connect.proto as proto |
| |
| class TestPolicy(DefaultPolicy): |
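        """A DefaultPolicy with small, fixed bounds so retry tests run quickly.

        A sketch of how such a policy is typically consumed (assuming the
        Retrying iterator imported above; `stub` and `request` are placeholders)::

            for attempt in Retrying(TestPolicy()):
                with attempt:
                    stub.ExecutePlan(request)
        """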
| def __init__(self): |
| super().__init__( |
| max_retries=3, |
| backoff_multiplier=4.0, |
| initial_backoff=10, |
| max_backoff=10, |
| jitter=10, |
| min_jitter_threshold=10, |
| ) |
| |
| class TestException(grpc.RpcError, grpc.Call): |
| """Exception mock to test retryable exceptions.""" |
| |
| def __init__( |
| self, |
| msg, |
| code=grpc.StatusCode.INTERNAL, |
| trailing_status: Union[status_pb2.Status, None] = None, |
| ): |
| self.msg = msg |
| self._code = code |
| self._trailer: dict[str, Any] = {} |
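            # gRPC carries rich error details in the "grpc-status-details-bin"
            # trailer as a serialized google.rpc.Status message.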
| if trailing_status is not None: |
| self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString() |
| |
| def code(self): |
| return self._code |
| |
| def __str__(self): |
| return self.msg |
| |
| def details(self): |
| return self.msg |
| |
| def trailing_metadata(self): |
| return None if not self._trailer else self._trailer.items() |
| |
    class ResponseGenerator(Generator):
        """Generates the values returned by the streaming iterator of the GRPC stub.
        Entries that are callable are invoked (and may raise); all other entries
        are returned as-is."""
| |
| def __init__(self, funs): |
| self._funs = funs |
| self._iterator = iter(self._funs) |
| |
| def send(self, value: Any) -> proto.ExecutePlanResponse: |
| val = next(self._iterator) |
| if callable(val): |
| return val() |
| else: |
| return val |
| |
        def throw(self, typ: Any = None, value: Any = None, traceback: Any = None) -> Any:
            super().throw(typ, value, traceback)
| |
| def close(self) -> None: |
| return super().close() |
| |
| class MockSparkConnectStub: |
| """Simple mock class for the GRPC stub used by the re-attachable execution.""" |
| |
| def __init__(self, execute_ops=None, attach_ops=None): |
| self._execute_ops = execute_ops |
| self._attach_ops = attach_ops |
| # Call counters |
| self.execute_calls = 0 |
| self.release_calls = 0 |
| self.release_until_calls = 0 |
| self.attach_calls = 0 |
| |
| def ExecutePlan(self, *args, **kwargs): |
| self.execute_calls += 1 |
| return self._execute_ops |
| |
| def ReattachExecute(self, *args, **kwargs): |
| self.attach_calls += 1 |
| return self._attach_ops |
| |
| def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): |
| if req.HasField("release_all"): |
| self.release_calls += 1 |
            elif req.HasField("release_until"):
                self.release_until_calls += 1
| |
    class MockService:
        """Simplest mock of the SparkConnectService.
        If this needs more complex logic, it should be replaced with Python mocking."""
| |
| req: Optional[proto.ExecutePlanRequest] |
| |
| def __init__(self, session_id: str): |
| self._session_id = session_id |
| self.req = None |
| |
| def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): |
| self.req = req |
| resp = proto.ExecutePlanResponse() |
| resp.session_id = self._session_id |
| resp.operation_id = req.operation_id |
| |
| pdf = pd.DataFrame(data={"col1": [1, 2]}) |
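            # Serialize a two-row table as an Arrow IPC stream; this is the payload
            # format the client expects in `arrow_batch.data`.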
| schema = pa.Schema.from_pandas(pdf) |
| table = pa.Table.from_pandas(pdf) |
| sink = pa.BufferOutputStream() |
| |
| writer = pa.ipc.new_stream(sink, schema=schema) |
| writer.write(table) |
| writer.close() |
| |
| buf = sink.getvalue() |
| resp.arrow_batch.data = buf.to_pybytes() |
| resp.arrow_batch.row_count = 2 |
| return [resp] |
| |
| def Interrupt(self, req: proto.InterruptRequest, metadata): |
| self.req = req |
| resp = proto.InterruptResponse() |
| resp.session_id = self._session_id |
| return resp |
| |
    # In these tests the `_cleanup_ml_cache` invocation would hang (there is no
    # valid Spark cluster) and block the test process from exiting, because it is
    # registered as an atexit handler in the `SparkConnectClient` constructor.
    # To bypass the issue, patch the method in the tests.
| SparkConnectClient._cleanup_ml_cache = lambda _: None |
| |
| |
| @unittest.skipIf(not should_test_connect, connect_requirement_message) |
| class SparkConnectClientTestCase(unittest.TestCase): |
| def test_user_agent_passthrough(self): |
| client = SparkConnectClient("sc://foo/;user_agent=bar", use_reattachable_execute=False) |
| mock = MockService(client._session_id) |
| client._stub = mock |
| |
| command = proto.Command() |
| client.execute_command(command) |
| |
| self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") |
| self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") |
| |
| def test_user_agent_default(self): |
| client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) |
| mock = MockService(client._session_id) |
| client._stub = mock |
| |
| command = proto.Command() |
| client.execute_command(command) |
| |
| self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") |
| self.assertRegex( |
| mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" |
| ) |
| |
| def test_properties(self): |
| client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) |
| self.assertEqual(client.token, "bar") |
| self.assertEqual(client.host, "foo") |
| |
| client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) |
| self.assertIsNone(client.token) |
| |
| def test_channel_builder(self): |
| class CustomChannelBuilder(DefaultChannelBuilder): |
| @property |
| def userId(self) -> Optional[str]: |
| return "abc" |
| |
| client = SparkConnectClient( |
| CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False |
| ) |
| |
| self.assertEqual(client._user_id, "abc") |
| |
| def test_interrupt_all(self): |
| client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) |
| mock = MockService(client._session_id) |
| client._stub = mock |
| |
| client.interrupt_all() |
| self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") |
| |
| def test_is_closed(self): |
| client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) |
| |
| self.assertFalse(client.is_closed) |
| client.close() |
| self.assertTrue(client.is_closed) |
| |
| def test_channel_builder_with_session(self): |
| dummy = str(uuid.uuid4()) |
| chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}") |
| client = SparkConnectClient(chan) |
| self.assertEqual(client._session_id, chan.session_id) |
| |
| def test_session_hook(self): |
| inits = 0 |
| calls = 0 |
| |
| class TestHook(RemoteSparkSession.Hook): |
| def __init__(self, _session): |
| nonlocal inits |
| inits += 1 |
| |
| def on_execute_plan(self, req): |
| nonlocal calls |
| calls += 1 |
| return req |
| |
| session = ( |
| RemoteSparkSession.builder.remote("sc://foo")._registerHook(TestHook).getOrCreate() |
| ) |
| self.assertEqual(inits, 1) |
| self.assertEqual(calls, 0) |
| session.client._stub = MockService(session.client._session_id) |
| session.client.disable_reattachable_execute() |
| |
| # Called from _execute_and_fetch_as_iterator |
| session.range(1).collect() |
| self.assertEqual(inits, 1) |
| self.assertEqual(calls, 1) |
| |
| # Called from _execute |
| session.udf.register("test_func", lambda x: x + 1) |
| self.assertEqual(inits, 1) |
| self.assertEqual(calls, 2) |
| |
| def test_custom_operation_id(self): |
| client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) |
| mock = MockService(client._session_id) |
| client._stub = mock |
| req = client._execute_plan_request_with_metadata( |
| operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b" |
| ) |
        for resp in client._stub.ExecutePlan(req, metadata=None):
            self.assertEqual(resp.operation_id, "10a4c38e-7e87-40ee-9d6f-60ff0751e63b")
| |
| |
| @unittest.skipIf(not should_test_connect, connect_requirement_message) |
| class SparkConnectClientReattachTestCase(unittest.TestCase): |
| def setUp(self) -> None: |
| self.request = proto.ExecutePlanRequest() |
| self.retrying = lambda: Retrying(TestPolicy()) |
| self.response = proto.ExecutePlanResponse( |
| response_id="1", |
| ) |
| self.finished = proto.ExecutePlanResponse( |
| result_complete=proto.ExecutePlanResponse.ResultComplete(), |
| response_id="2", |
| ) |
| |
| def _stub_with(self, execute=None, attach=None): |
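        """Build a mock stub whose ExecutePlan/ReattachExecute streams replay the given ops."""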
| return MockSparkConnectStub( |
| execute_ops=ResponseGenerator(execute) if execute is not None else None, |
| attach_ops=ResponseGenerator(attach) if attach is not None else None, |
| ) |
| |
| def test_basic_flow(self): |
| stub = self._stub_with([self.response, self.finished]) |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
        for _ in ite:
            pass
| |
| def check_all(): |
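            # ReleaseExecute is issued asynchronously from a background thread, so
            # the call counters are polled with `eventually` rather than asserted once.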
| self.assertEqual(0, stub.attach_calls) |
| self.assertEqual(1, stub.release_until_calls) |
| self.assertEqual(1, stub.release_calls) |
| self.assertEqual(1, stub.execute_calls) |
| |
| eventually(timeout=1, catch_assertions=True)(check_all)() |
| |
| def test_fail_during_execute(self): |
| def fatal(): |
| raise TestException("Fatal") |
| |
| stub = self._stub_with([self.response, fatal]) |
| with self.assertRaises(TestException): |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
            for _ in ite:
                pass
| |
| def check(): |
| self.assertEqual(0, stub.attach_calls) |
| self.assertEqual(1, stub.release_calls) |
| self.assertEqual(1, stub.release_until_calls) |
| self.assertEqual(1, stub.execute_calls) |
| |
| eventually(timeout=1, catch_assertions=True)(check)() |
| |
| def test_fail_and_retry_during_execute(self): |
| def non_fatal(): |
| raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) |
| |
| stub = self._stub_with( |
| [self.response, non_fatal], [self.response, self.response, self.finished] |
| ) |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
        for _ in ite:
            pass
| |
| def check(): |
| self.assertEqual(1, stub.attach_calls) |
| self.assertEqual(1, stub.release_calls) |
| self.assertEqual(3, stub.release_until_calls) |
| self.assertEqual(1, stub.execute_calls) |
| |
| eventually(timeout=1, catch_assertions=True)(check)() |
| |
| def test_fail_and_retry_during_reattach(self): |
| count = 0 |
| |
| def non_fatal(): |
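            # Fail the first two invocations with a retryable error, then succeed.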
| nonlocal count |
| if count < 2: |
| count += 1 |
| raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) |
| else: |
| return proto.ExecutePlanResponse() |
| |
| stub = self._stub_with( |
| [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished] |
| ) |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
        for _ in ite:
            pass
| |
| def check(): |
| self.assertEqual(2, stub.attach_calls) |
| self.assertEqual(3, stub.release_until_calls) |
| self.assertEqual(1, stub.release_calls) |
| self.assertEqual(1, stub.execute_calls) |
| |
| eventually(timeout=1, catch_assertions=True)(check)() |
| |
| def test_not_found_recovers(self): |
| """SPARK-48056: Assert that the client recovers from session or operation not |
| found error if no partial responses were previously received. |
| """ |
| |
| def not_found_recovers(error_code: str): |
| def not_found(): |
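                # google.rpc Status code 14 corresponds to grpc.StatusCode.UNAVAILABLE;
                # the INVALID_HANDLE error class is carried in the status message.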
| raise TestException( |
| error_code, |
| grpc.StatusCode.UNAVAILABLE, |
| trailing_status=status_pb2.Status(code=14, message=error_code, details=""), |
| ) |
| |
| stub = self._stub_with([not_found, self.finished]) |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
| |
| for _ in ite: |
| pass |
| |
| def checks(): |
| self.assertEqual(2, stub.execute_calls) |
| self.assertEqual(0, stub.attach_calls) |
| self.assertEqual(1, stub.release_calls) |
| self.assertEqual(0, stub.release_until_calls) |
| |
| eventually(timeout=1, catch_assertions=True)(checks)() |
| |
| parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] |
        for error_code in parameters:
            not_found_recovers(error_code)
| |
| def test_not_found_fails(self): |
| """SPARK-48056: Assert that the client fails from session or operation not found error |
| if a partial response was previously received. |
| """ |
| |
| def not_found_fails(error_code: str): |
| def not_found(): |
| raise TestException( |
| error_code, |
| grpc.StatusCode.UNAVAILABLE, |
| trailing_status=status_pb2.Status(code=14, message=error_code, details=""), |
| ) |
| |
| stub = self._stub_with([self.response], [not_found]) |
| |
| with self.assertRaises(PySparkRuntimeError) as e: |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
| for _ in ite: |
| pass |
| |
                self.assertIn("RESPONSE_ALREADY_RECEIVED", e.exception.getMessage())
                self.assertIn(error_code, e.exception.getMessage())
| |
| def checks(): |
| self.assertEqual(1, stub.execute_calls) |
| self.assertEqual(1, stub.attach_calls) |
| self.assertEqual(1, stub.release_calls) |
| self.assertEqual(1, stub.release_until_calls) |
| |
| eventually(timeout=1, catch_assertions=True)(checks)() |
| |
| parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] |
        for error_code in parameters:
            not_found_fails(error_code)
| |
| def test_observed_session_id(self): |
| stub = self._stub_with([self.response, self.finished]) |
| ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) |
| session_id = "test-session-id" |
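        # The reattach request echoes the server-side session id the client has
        # observed, or "" if none has been observed yet.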
| |
| reattach = ite._create_reattach_execute_request() |
| self.assertEqual(reattach.client_observed_server_side_session_id, "") |
| |
| self.request.client_observed_server_side_session_id = session_id |
| reattach = ite._create_reattach_execute_request() |
| self.assertEqual(reattach.client_observed_server_side_session_id, session_id) |
| |
| |
| if __name__ == "__main__": |
| from pyspark.sql.tests.connect.client.test_client import * # noqa: F401 |
| |
| try: |
| import xmlrunner # type: ignore |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |