#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import warnings

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

if should_test_connect:
    import grpc
    import google.protobuf.any_pb2 as any_pb2
    import google.protobuf.duration_pb2 as duration_pb2
    from google.rpc import status_pb2
    from google.rpc import error_details_pb2
    from pyspark.sql.connect.client import SparkConnectClient
    from pyspark.sql.connect.client.retries import (
        Retrying,
        DefaultPolicy,
    )
    from pyspark.sql.tests.connect.client.test_client import (
        TestPolicy,
        TestException,
    )


class SleepTimeTracker:
    """Tracks sleep times in ms for testing purposes."""

    def __init__(self):
        self._times = []

    def sleep(self, t: float):
        # The injected sleep callable receives the wait time in seconds
        # (hence the x1000); record it in milliseconds.
        self._times.append(int(1000 * t))

    @property
    def times(self):
        return list(self._times)
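
# A minimal usage sketch (`policies` here stands in for a retry-policy list
# such as `client.get_retry_policies()`): injecting `tracker.sleep` lets a
# test observe backoff decisions without actually sleeping.
#
#   tracker = SleepTimeTracker()
#   try:
#       for attempt in Retrying(policies, sleep=tracker.sleep):
#           with attempt:
#               raise TestException("fail", grpc.StatusCode.UNAVAILABLE)
#   except TestException:
#       pass
#   tracker.times  # recorded waits, in milliseconds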


def create_test_exception_with_details(
    msg: str,
    code: grpc.StatusCode = grpc.StatusCode.INTERNAL,
    retry_delay: int = 0,
) -> TestException:
    """Helper function for creating a TestException with additional error details
    like retry_delay (in milliseconds).
    """
    retry_delay_msg = duration_pb2.Duration()
    retry_delay_msg.FromMilliseconds(retry_delay)
    retry_info = error_details_pb2.RetryInfo()
    retry_info.retry_delay.CopyFrom(retry_delay_msg)
    # Pack RetryInfo into an Any type
    retry_info_any = any_pb2.Any()
    retry_info_any.Pack(retry_info)
    status = status_pb2.Status(
        # grpc.StatusCode member values are (numeric code, description) tuples;
        # google.rpc.Status expects the numeric code.
        code=code.value[0],
        message=msg,
        details=[retry_info_any],
    )
    return TestException(msg=msg, code=code, trailing_status=status)
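
# Note: this helper mirrors how gRPC servers report rich errors: a
# google.rpc.Status carrying RetryInfo is serialized into the
# `grpc-status-details-bin` trailer, and the client unpacks RetryInfo to honor
# the server-suggested retry_delay. (That trailer name is standard gRPC
# behavior, not something these tests assert.)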


def get_client_policies_map(client: SparkConnectClient) -> dict:
    return {type(policy): policy for policy in client.get_retry_policies()}


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientRetriesTestCase(unittest.TestCase):
    def assertListsAlmostEqual(self, first, second, places=None, msg=None, delta=None):
        self.assertEqual(len(first), len(second), msg)
        for i in range(len(first)):
            self.assertAlmostEqual(first[i], second[i], places, msg, delta)

    def test_retry(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        sleep_tracker = SleepTimeTracker()
        try:
            for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep):
                with attempt:
                    raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE)
        except TestException:
            pass
        # Tolerated at least 10 minutes of failures; tracked times are in ms.
        self.assertGreaterEqual(sum(sleep_tracker.times), 600_000)

    def test_retry_client_unit(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        policyA = TestPolicy()
        policyB = DefaultPolicy()
        client.set_retry_policies([policyA, policyB])
        self.assertEqual(client.get_retry_policies(), [policyA, policyB])

    def test_warning_works(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        policy = get_client_policies_map(client).get(DefaultPolicy)
        self.assertIsNotNone(policy)
        sleep_tracker = SleepTimeTracker()
        with warnings.catch_warnings(record=True) as warning_list:
            warnings.simplefilter("always")
            try:
                for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep):
                    with attempt:
                        raise TestException(
                            msg="Some error message", code=grpc.StatusCode.UNAVAILABLE
                        )
            except TestException:
                pass
        self.assertEqual(len(sleep_tracker.times), policy.max_retries)
        self.assertEqual(len(warning_list), 1)
        self.assertEqual(
            str(warning_list[0].message),
            "[RETRIES_EXCEEDED] The maximum number of retries has been exceeded.",
        )

    def test_default_policy_retries_retry_info(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        policy = get_client_policies_map(client).get(DefaultPolicy)
        self.assertIsNotNone(policy)
        # retry delay = 0; the error code is not matched by any policy.
        # Tests that errors carrying RetryInfo are retried by the DefaultPolicy.
        retry_delay = 0
        sleep_tracker = SleepTimeTracker()
        try:
            for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep):
                with attempt:
                    raise create_test_exception_with_details(
                        msg="Some error message",
                        code=grpc.StatusCode.UNIMPLEMENTED,
                        retry_delay=retry_delay,
                    )
        except TestException:
            pass
        # With no server-suggested delay, the client falls back to exponential
        # backoff capped at max_backoff.
        expected_times = [
            min(policy.max_backoff, policy.initial_backoff * policy.backoff_multiplier**i)
            for i in range(policy.max_retries)
        ]
        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter)

    def test_retry_delay_overrides_max_backoff(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        policy = get_client_policies_map(client).get(DefaultPolicy)
        self.assertIsNotNone(policy)
        # retry delay = 5 minutes.
        # Tests that a server-suggested retry_delay overrides max_backoff.
        retry_delay = 5 * 60 * 1000
        sleep_tracker = SleepTimeTracker()
        # Assert that retry_delay exceeds max_backoff so the test stays meaningful.
        self.assertGreaterEqual(retry_delay, policy.max_backoff)
        try:
            for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep):
                with attempt:
                    raise create_test_exception_with_details(
                        "Some error message",
                        grpc.StatusCode.UNAVAILABLE,
                        retry_delay,
                    )
        except TestException:
            pass
        expected_times = [retry_delay] * policy.max_retries
        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter)

    def test_max_server_retry_delay(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        policy = get_client_policies_map(client).get(DefaultPolicy)
        self.assertIsNotNone(policy)
        # retry delay = 10 hours.
        # Tests that the max_server_retry_delay limit caps excessive server delays.
        retry_delay = 10 * 60 * 60 * 1000
        sleep_tracker = SleepTimeTracker()
        try:
            for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep):
                with attempt:
                    raise create_test_exception_with_details(
                        "Some error message",
                        grpc.StatusCode.UNAVAILABLE,
                        retry_delay,
                    )
        except TestException:
            pass
        expected_times = [policy.max_server_retry_delay] * policy.max_retries
        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter)

    def test_return_to_exponential_backoff(self):
        client = SparkConnectClient("sc://foo/;token=bar")
        policy = get_client_policies_map(client).get(DefaultPolicy)
        self.assertIsNotNone(policy)
        # Start with retry_delay = 5 minutes, then drop it to zero.
        # Tests that backoff returns to the client's exponential strategy.
        initial_retry_delay = 5 * 60 * 1000
        sleep_tracker = SleepTimeTracker()
        try:
            for i, attempt in enumerate(
                Retrying(client._retry_policies, sleep=sleep_tracker.sleep)
            ):
                if i < 2:
                    retry_delay = initial_retry_delay
                elif i < 5:
                    retry_delay = 0
                else:
                    break
                with attempt:
                    raise create_test_exception_with_details(
                        "Some error message",
                        grpc.StatusCode.UNAVAILABLE,
                        retry_delay,
                    )
        except TestException:
            pass
        # Two server-dictated waits, then exponential backoff resumes as if the
        # earlier attempts had still been counted (exponents continue at i = 2).
        expected_times = [initial_retry_delay] * 2 + [
            policy.initial_backoff * policy.backoff_multiplier**i for i in range(2, 5)
        ]
        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter)
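
    # For reference, the backoff pattern asserted across these tests (a sketch;
    # the actual constants live in DefaultPolicy):
    #   wait(i) = min(max_backoff, initial_backoff * backoff_multiplier ** i)
    # unless the server supplies RetryInfo.retry_delay, which takes precedence
    # and is clamped to max_server_retry_delay.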


if __name__ == "__main__":
    from pyspark.sql.tests.connect.client.test_client_retries import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)