blob: b4c9e4c67c8bbc393ab37d3f283fea906b3a6e6d [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
class SparkConnectCloneSessionTest(SparkConnectSQLTestCase):
    """Tests for ``cloneSession()`` on the Spark Connect session ``self.connect``.

    The cases cover: copying of SQL configuration and temporary views into the
    clone, independence of the two sessions after cloning, explicit and
    auto-generated session IDs, and rejection of a malformed session ID.
    """

    @staticmethod
    def _conf_value(session, key: str):
        """Return the value column of the first row produced by ``SET <key>``."""
        return session.sql(f"SET {key}").collect()[0][1]

    @staticmethod
    def _row_count(session, view_name: str):
        """Return the result of ``SELECT COUNT(*)`` over ``view_name`` in ``session``."""
        return session.sql(f"SELECT COUNT(*) FROM {view_name}").collect()[0][0]

    def test_clone_session_basic(self):
        """Test basic session cloning functionality."""
        # Set a configuration in the original session.
        self.connect.sql("SET spark.test.original = 'value1'")

        # The value is asserted with its quotes included, exactly as read back.
        self.assertEqual(self._conf_value(self.connect, "spark.test.original"), "'value1'")

        # Clone the session.
        cloned_session = self.connect.cloneSession()

        # The configuration set before cloning must be visible in the clone.
        self.assertEqual(self._conf_value(cloned_session, "spark.test.original"), "'value1'")

        # Set a different value in each session ...
        cloned_session.sql("SET spark.test.original = 'modified_clone'")
        self.connect.sql("SET spark.test.original = 'modified_original'")

        # ... and verify neither write leaked into the other session.
        original_final = self._conf_value(self.connect, "spark.test.original")
        cloned_final = self._conf_value(cloned_session, "spark.test.original")
        self.assertEqual(original_final, "'modified_original'")
        self.assertEqual(cloned_final, "'modified_clone'")

    def test_clone_session_with_custom_id(self):
        """Test cloning session with a custom session ID."""
        custom_session_id = str(uuid.uuid4())
        cloned_session = self.connect.cloneSession(custom_session_id)

        self.assertEqual(cloned_session.session_id, custom_session_id)
        self.assertNotEqual(self.connect.session_id, cloned_session.session_id)

    def test_clone_session_preserves_temp_views(self):
        """Test that temporary views are preserved in cloned sessions."""
        # Create a temporary view in the original session.
        df = self.connect.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
        df.createOrReplaceTempView("test_table")

        # Verify the original session can access the temp view.
        self.assertEqual(self._row_count(self.connect, "test_table"), 2)

        # Clone the session.
        cloned_session = self.connect.cloneSession()

        # The clone must see the view that existed at clone time. This has to be
        # checked *before* the clone replaces the view, otherwise preservation
        # is never actually exercised.
        self.assertEqual(self._row_count(cloned_session, "test_table"), 2)

        # Replace the view in the clone with a different number of rows, so a
        # catalog shared between the two sessions would change the original's
        # count and be detected below.
        df_clone = cloned_session.createDataFrame(
            [(3, "c"), (4, "d"), (5, "e")], ["id", "value"]
        )
        df_clone.createOrReplaceTempView("test_table")

        # Verify both sessions have independent temp views.
        original_count_final = self._row_count(self.connect, "test_table")
        cloned_count_final = self._row_count(cloned_session, "test_table")
        self.assertEqual(original_count_final, 2)  # Original data is untouched
        self.assertEqual(cloned_count_final, 3)  # Clone sees its own replacement

    def test_temp_views_independence_after_cloning(self):
        """Test that temp views are cloned and then can be modified independently."""
        # Create the initial temp view before cloning.
        df = self.connect.createDataFrame([(1, "original")], ["id", "type"])
        df.createOrReplaceTempView("shared_view")

        # Verify the original session can access the temp view.
        self.assertEqual(self._row_count(self.connect, "shared_view"), 1)

        # Clone the session - temp views should be preserved.
        cloned_session = self.connect.cloneSession()

        # Verify the cloned session can access its copy of the temp view.
        self.assertEqual(self._row_count(cloned_session, "shared_view"), 1)

        # Now replace the view in each session independently.
        # Original session: two rows, via the DataFrame API.
        original_df = self.connect.createDataFrame(
            [(2, "modified_original"), (3, "another")], ["id", "type"]
        )
        original_df.createOrReplaceTempView("shared_view")

        # Cloned session: one row, via SQL.
        cloned_session.sql(
            "CREATE OR REPLACE TEMPORARY VIEW shared_view AS SELECT 4 as id, 'cloned_data' as type"
        )

        # Verify they now have different content (independence after cloning).
        original_count_after = self._row_count(self.connect, "shared_view")
        cloned_count_after = self._row_count(cloned_session, "shared_view")
        self.assertEqual(original_count_after, 2)  # Original session: 2 rows
        self.assertEqual(cloned_count_after, 1)  # Cloned session: 1 row

    def test_invalid_session_id_format(self):
        """Test that invalid session ID format raises an exception."""
        with self.assertRaises(Exception) as context:
            self.connect.cloneSession("not-a-valid-uuid")

        # These checks must sit outside the ``with`` block: statements after the
        # raising call inside it would never run.
        message = str(context.exception)

        # Verify it contains our clone-specific error message.
        self.assertIn("INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT", message)
        self.assertIn("not-a-valid-uuid", message)

    def test_clone_session_auto_generated_id(self):
        """Test that cloneSession() without arguments generates a valid UUID."""
        cloned_session = self.connect.cloneSession()

        # Verify the two sessions have different session IDs.
        self.assertNotEqual(self.connect.session_id, cloned_session.session_id)

        # Verify the new session ID parses as a UUID.
        try:
            uuid.UUID(cloned_session.session_id)
        except ValueError:
            self.fail("Generated session ID is not a valid UUID")
if __name__ == "__main__":
    # When executed directly, hand control to the test entry point that
    # pyspark.testing exposes as ``main``.
    from pyspark.testing import main as run_tests

    run_tests()