#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession as PySparkSession
from pyspark.testing.connectutils import ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
from pyspark.testing.utils import eventually


@unittest.skipIf(is_remote_only(), "Requires JVM access")
class SparkConnectReattachTestCase(ReusedMixedTestCase, PandasOnSparkTestUtils):
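    """Test that releasing a Spark Connect session interrupts in-flight
    executions and fails subsequent requests on that session."""
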
    def test_release_sessions(self):
        big_enough_query = "select * from range(1000000)"
        query1 = self.connect.sql(big_enough_query).toLocalIterator()
        query2 = self.connect.sql(big_enough_query).toLocalIterator()
        query3 = self.connect.sql("select 1").toLocalIterator()
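
        # Consume the first row of the two long-running queries so that both
        # are actually submitted to the server; query3 is left unsubmitted.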
        next(query1)
        next(query2)
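
        # Reach into the JVM to get the SparkConnectService companion object
        # (Scala objects are exposed to Py4J as a "<Name>$" class whose
        # singleton instance lives in the static MODULE$ field).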
        jvm = PySparkSession._instantiatedSession._jvm  # type: ignore[union-attr]
        service = getattr(
            getattr(
                jvm.org.apache.spark.sql.connect.service,  # type: ignore[union-attr]
                "SparkConnectService$",
            ),
            "MODULE$",
        )
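
        # Executions are registered asynchronously, so poll (via eventually,
        # which retries caught assertion errors) until the server tracks both.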
        @eventually(catch_assertions=True)
        def wait_for_requests():
            self.assertEqual(service.executionManager().listExecuteHolders().length(), 2)

        wait_for_requests()

        # Close the session.
        self.connect.client.release_session()
        # Calling release_session again should be a no-op.
        self.connect.client.release_session()
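
        # Releasing the session should interrupt both running executions and
        # remove their holders from the execution manager.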
        @eventually(catch_assertions=True)
        def wait_for_responses():
            self.assertEqual(service.executionManager().listExecuteHolders().length(), 0)

        wait_for_responses()

        # query1 and query2 may fail with either:
        # - OPERATION_CANCELED if it happens fast: closing the session interrupted the queries,
        #   and that error was pushed to the client buffers before the client got disconnected.
        # - INVALID_HANDLE.SESSION_CLOSED if it happens slow: closing the session interrupted
        #   the client RPCs before the error above was pushed out. The client then gets an
        #   INVALID_CURSOR.DISCONNECTED, retries with a ReattachExecute, and finally gets an
        #   INVALID_HANDLE.SESSION_CLOSED.
        def check_error(q):
            try:
                list(q)  # Iterate all.
            except Exception as e:  # noqa: F841
                return e

        e = check_error(query1)
        self.assertIsNotNone(e, "An exception has to be thrown")
        self.assertTrue("OPERATION_CANCELED" in str(e) or "INVALID_HANDLE.SESSION_CLOSED" in str(e))
        e = check_error(query2)
        self.assertIsNotNone(e, "An exception has to be thrown")
        self.assertTrue("OPERATION_CANCELED" in str(e) or "INVALID_HANDLE.SESSION_CLOSED" in str(e))
        # query3 has not been submitted before, so it should now fail with SESSION_CLOSED.
        e = check_error(query3)
        self.assertIsNotNone(e, "An exception has to be thrown")
        self.assertIn("INVALID_HANDLE.SESSION_CLOSED", str(e))


if __name__ == "__main__":
    from pyspark.sql.tests.connect.client.test_reattach import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)