#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import grpc
import random
import time
import typing
import warnings
from google.rpc import error_details_pb2
from grpc_status import rpc_status
from typing import Optional, Callable, Generator, List, Type, cast
from types import TracebackType
from pyspark.sql.connect.logging import logger
from pyspark.errors import PySparkRuntimeError
"""
This module contains retry system. The system is designed to be
significantly customizable.
A key aspect of retries is RetryPolicy class, describing a single policy.
There can be more than one policy defined at the same time. Each policy
determines which error types it can retry and how exactly.
For instance, networking errors should likely be retried differently that
remote resource being unavailable.
Given a sequence of policies, retry logic applies all of them in sequential
order, keeping track of different policies budgets.
"""
class RetryPolicy:
"""
Describes key aspects of a retry policy.
It is advised that different policies are implemented as different subclasses.
Parameters
----------
max_retries: int, optional
Maximum number of retries.
initial_backoff: int
Start value of the exponential backoff.
max_backoff: int, optional
Maximal value of the exponential backoff.
backoff_multiplier: float
Multiplicative base of the exponential backoff.
jitter: int
Sample a random value uniformly from the range [0, jitter] and add it to the backoff.
min_jitter_threshold: int
Minimum value of the backoff at which random jitter starts being added.
recognize_server_retry_delay: bool
Per gRPC standard, the server can send error messages that contain `RetryInfo` message
with `retry_delay` field indicating that the client should wait for at least `retry_delay`
amount of time before retrying again, see:
https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91
If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field
in the backoff computation. Server's `retry_delay` can override client's `max_backoff`.
This flag does not change which errors are retried, only how the backoff is computed.
`DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`.
max_server_retry_delay: int, optional
Limit for the server-provided `retry_delay`.
"""
def __init__(
self,
max_retries: Optional[int] = None,
initial_backoff: int = 1000,
max_backoff: Optional[int] = None,
backoff_multiplier: float = 1.0,
jitter: int = 0,
min_jitter_threshold: int = 0,
recognize_server_retry_delay: bool = False,
max_server_retry_delay: Optional[int] = None,
):
self.max_retries = max_retries
self.initial_backoff = initial_backoff
self.max_backoff = max_backoff
self.backoff_multiplier = backoff_multiplier
self.jitter = jitter
self.min_jitter_threshold = min_jitter_threshold
self.recognize_server_retry_delay = recognize_server_retry_delay
self.max_server_retry_delay = max_server_retry_delay
self._name = self.__class__.__name__
@property
def name(self) -> str:
return self._name
def can_retry(self, exception: BaseException) -> bool:
return False
def to_state(self) -> "RetryPolicyState":
return RetryPolicyState(self)
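# Illustrative sketch, not part of the module: a minimal custom policy. Subclasses
# normally only override `can_retry`; this hypothetical policy retries gRPC
# DEADLINE_EXCEEDED errors with a small fixed budget.
class _ExampleDeadlinePolicy(RetryPolicy):
    def __init__(self) -> None:
        super().__init__(max_retries=3, initial_backoff=200, max_backoff=2000)
    def can_retry(self, exception: BaseException) -> bool:
        return (
            isinstance(exception, grpc.RpcError)
            and exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED
        )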
class RetryPolicyState:
"""
This class represents the stateful part of a specific policy.
"""
def __init__(self, policy: RetryPolicy):
self._policy = policy
# Will allow attempts [0, self._policy.max_retries)
self._attempt = 0
self._next_wait: float = self._policy.initial_backoff
@property
def policy(self) -> RetryPolicy:
return self._policy
@property
def name(self) -> str:
return self.policy.name
def can_retry(self, exception: BaseException) -> bool:
return self.policy.can_retry(exception)
def next_attempt(self, exception: Optional[BaseException] = None) -> Optional[int]:
"""
Returns
-------
Randomized time (in milliseconds) to wait until this attempt
or None if this policy doesn't allow more retries.
"""
if self.policy.max_retries is not None and self._attempt >= self.policy.max_retries:
# No more retries under this policy
return None
self._attempt += 1
wait_time = self._next_wait
# Calculate future backoff
if self.policy.max_backoff is not None:
self._next_wait = min(
float(self.policy.max_backoff), wait_time * self.policy.backoff_multiplier
)
if exception is not None and self.policy.recognize_server_retry_delay:
retry_delay = extract_retry_delay(exception)
if retry_delay is not None:
logger.debug(f"The server has sent a retry delay of {retry_delay} ms.")
if self.policy.max_server_retry_delay is not None:
retry_delay = min(retry_delay, self.policy.max_server_retry_delay)
wait_time = max(wait_time, retry_delay)
# Jitter current backoff, after the future backoff was computed
if wait_time >= self.policy.min_jitter_threshold:
wait_time += random.uniform(0, self.policy.jitter)
# Round to whole number of milliseconds
return int(wait_time)
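# Illustrative sketch, hypothetical helper not part of the module: drive a
# RetryPolicyState by hand to observe the exponential backoff schedule. With
# max_retries=5, initial_backoff=100, backoff_multiplier=2.0, max_backoff=800 and
# no jitter, the successive waits are 100, 200, 400, 800, 800 ms, then None once
# the retry budget is spent.
def _example_backoff_schedule() -> List[Optional[int]]:
    state = RetryPolicy(
        max_retries=5, initial_backoff=100, backoff_multiplier=2.0, max_backoff=800
    ).to_state()
    return [state.next_attempt() for _ in range(6)]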
class AttemptManager:
"""
Simple ContextManager that is used to capture the exception thrown inside the context.
"""
def __init__(self, retrying: "Retrying") -> None:
self._retrying = retrying
def __enter__(self) -> None:
pass
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exception: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
if isinstance(exception, BaseException):
# Swallow the exception.
if self._retrying.accept_exception(exception):
return True
# Bubble up the exception.
return False
else:
self._retrying.accept_succeeded()
return None
class Retrying:
"""
This class is a point of entry into the retry logic.
The class accepts a list of retry policies and applies them in the given order.
The first policy that accepts an exception is used; once an error is matched by
one policy, the remaining policies are skipped.
The usage of the class should be as follows:
for attempt in Retrying(...):
    with attempt:
        Do something that can throw an exception
If an error is considered retriable, it is retried according to the policies and
re-raised once the retry limit is exceeded.
Exceptions not considered retriable are passed through transparently.
"""
def __init__(
self,
policies: typing.Union[RetryPolicy, typing.Iterable[RetryPolicy]],
sleep: Callable[[float], None] = time.sleep,
) -> None:
if isinstance(policies, RetryPolicy):
policies = [policies]
self._policies: List[RetryPolicyState] = [policy.to_state() for policy in policies]
self._sleep = sleep
self._exception: Optional[BaseException] = None
self._done = False
def can_retry(self, exception: BaseException) -> bool:
if isinstance(exception, RetryException):
return True
return any(policy.can_retry(exception) for policy in self._policies)
def accept_exception(self, exception: BaseException) -> bool:
if self.can_retry(exception):
self._exception = exception
return True
return False
def accept_succeeded(self) -> None:
self._done = True
def _last_exception(self) -> BaseException:
if self._exception is None:
raise PySparkRuntimeError(
errorClass="NO_ACTIVE_EXCEPTION",
messageParameters={},
)
return self._exception
def _wait(self) -> None:
exception = self._last_exception()
if isinstance(exception, RetryException):
# Considered immediately retriable
logger.debug(f"Got error: {repr(exception)}. Retrying.")
return
# Attempt to find a policy to wait with
matched_policy = None
for policy in self._policies:
if policy.can_retry(exception):
matched_policy = policy
break
if matched_policy is not None:
wait_time = matched_policy.next_attempt(exception)
if wait_time is not None:
logger.debug(
f"Got error: {repr(exception)}. "
+ f"Will retry after {wait_time} ms (policy: {matched_policy.name})"
)
self._sleep(wait_time / 1000)
return
# Exceeded retries
logger.debug(f"Given up on retrying. error: {repr(exception)}")
warnings.warn("[RETRIES_EXCEEDED] The maximum number of retries has been exceeded.")
raise exception
def __iter__(self) -> Generator[AttemptManager, None, None]:
"""
Generator function to wrap the exception-producing code block.
Returns
-------
A generator that yields the current attempt.
"""
# First attempt is free, no need to do waiting.
yield AttemptManager(self)
while not self._done:
self._wait()
yield AttemptManager(self)
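# Illustrative sketch, hypothetical helper not part of the module: the intended
# usage pattern of Retrying. `make_call` stands in for any code that may raise a
# retriable error; a no-op sleep is injected so the sketch never actually waits.
# DefaultPolicy is defined further below; the name resolves when the sketch is called.
def _example_retrying_usage(make_call: Callable[[], None]) -> None:
    for attempt in Retrying(DefaultPolicy(), sleep=lambda _: None):
        with attempt:
            make_call()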
class RetryException(Exception):
"""
An exception that can be raised from inside a retried block and is always retryable,
even when no policy matches it.
"""
class DefaultPolicy(RetryPolicy):
# Please synchronize changes here with Scala side in
# org.apache.spark.sql.connect.client.RetryPolicy
#
# Note: the number of retries is selected so that the maximum tolerated wait
# is guaranteed to be at least 10 minutes
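# Illustrative arithmetic, derived from the defaults below and ignoring jitter and
# server-provided delays: the successive backoffs are 50, 200, 800, 3200, 12800,
# 51200 ms, then capped at 60000 ms for the remaining 9 retries, for a total wait
# of roughly 68 + 540 = 608 seconds across 15 retries, just above the 10 minute target.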
def __init__(
self,
max_retries: Optional[int] = 15,
backoff_multiplier: float = 4.0,
initial_backoff: int = 50,
max_backoff: Optional[int] = 60000,
jitter: int = 500,
min_jitter_threshold: int = 2000,
recognize_server_retry_delay: bool = True,
max_server_retry_delay: Optional[int] = 10 * 60 * 1000, # 10 minutes
):
super().__init__(
max_retries=max_retries,
backoff_multiplier=backoff_multiplier,
initial_backoff=initial_backoff,
max_backoff=max_backoff,
jitter=jitter,
min_jitter_threshold=min_jitter_threshold,
recognize_server_retry_delay=recognize_server_retry_delay,
max_server_retry_delay=max_server_retry_delay,
)
def can_retry(self, e: BaseException) -> bool:
"""
Helper function that is used to identify if an exception thrown by the server
can be retried or not.
Parameters
----------
e : Exception
The gRPC error as received from the server. Typed as Exception, because other
exceptions thrown during client processing can be passed here as well.
Returns
-------
True if the exception can be retried, False otherwise.
"""
if not isinstance(e, grpc.RpcError):
return False
if e.code() in [grpc.StatusCode.INTERNAL]:
msg = str(e)
# This error happens if another RPC preempts this RPC.
if "INVALID_CURSOR.DISCONNECTED" in msg:
return True
if e.code() == grpc.StatusCode.UNAVAILABLE:
return True
if extract_retry_info(e) is not None:
# All error messages containing `RetryInfo` should be retried.
return True
return False
def extract_retry_info(exception: BaseException) -> Optional[error_details_pb2.RetryInfo]:
"""Extract and return RetryInfo from the grpc.RpcError"""
if isinstance(exception, grpc.RpcError):
status = rpc_status.from_call(cast(grpc.Call, exception))
if status:
for d in status.details:
if d.Is(error_details_pb2.RetryInfo.DESCRIPTOR):
info = error_details_pb2.RetryInfo()
d.Unpack(info)
return info
return None
def extract_retry_delay(exception: BaseException) -> Optional[int]:
"""Extract and return RetryInfo.retry_delay in milliseconds from grpc.RpcError if present."""
retry_info = extract_retry_info(exception)
if retry_info is not None:
return retry_info.retry_delay.ToMilliseconds()
return None
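# Illustrative sketch, hypothetical helper not part of the module: RetryInfo carries
# `retry_delay` as a protobuf Duration; `ToMilliseconds` performs the conversion used
# by extract_retry_delay above.
def _example_retry_delay_conversion() -> int:
    info = error_details_pb2.RetryInfo()
    info.retry_delay.FromMilliseconds(2500)  # as if the server asked for a 2.5 s delay
    return info.retry_delay.ToMilliseconds()  # 2500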