blob: f2006ab5ec8bb1403fa11183118265cc16a4704e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import grpc
import random
import time
import typing
from typing import Optional, Callable, Generator, List, Type
from types import TracebackType
from pyspark.sql.connect.client.logging import logger
from pyspark.errors import PySparkRuntimeError, RetriesExceeded
"""
This module contains retry system. The system is designed to be
significantly customizable.
A key aspect of retries is RetryPolicy class, describing a single policy.
There can be more than one policy defined at the same time. Each policy
determines which error types it can retry and how exactly.
For instance, networking errors should likely be retried differently than
a remote resource being unavailable.
Given a sequence of policies, retry logic applies all of them in sequential
order, keeping track of different policies budgets.
"""
class RetryPolicy:
    """
    Base description of a single retry policy.

    A policy bundles the knobs that shape its backoff schedule (attempt
    budget, initial/maximum wait, multiplier, jitter) and decides via
    ``can_retry`` which exceptions it is willing to handle. Distinct retry
    behaviors should be implemented as distinct subclasses.
    """

    def __init__(
        self,
        max_retries: Optional[int] = None,
        initial_backoff: int = 1000,
        max_backoff: Optional[int] = None,
        backoff_multiplier: float = 1.0,
        jitter: int = 0,
        min_jitter_threshold: int = 0,
    ):
        # Attempt budget; None means the budget is unbounded.
        self.max_retries = max_retries
        # Backoff schedule, all in milliseconds.
        self.initial_backoff = initial_backoff
        self.max_backoff = max_backoff
        self.backoff_multiplier = backoff_multiplier
        # Random jitter added to waits at or above the threshold.
        self.jitter = jitter
        self.min_jitter_threshold = min_jitter_threshold
        # Cached for cheap, stable reporting in logs.
        self._name = self.__class__.__name__

    @property
    def name(self) -> str:
        return self._name

    def can_retry(self, exception: BaseException) -> bool:
        # The base policy never matches; subclasses override this.
        return False

    def to_state(self) -> "RetryPolicyState":
        # Fresh mutable state (attempt counter, next wait) for one retried call.
        return RetryPolicyState(self)
class RetryPolicyState:
    """
    Mutable per-call companion of a ``RetryPolicy``.

    Tracks how many attempts have been consumed under the policy and what
    the next (pre-jitter) backoff would be.
    """

    def __init__(self, policy: "RetryPolicy"):
        self._policy = policy
        # Counted from zero; attempts [0, policy.max_retries) are allowed.
        self._attempt = 0
        self._next_wait: float = policy.initial_backoff

    @property
    def policy(self) -> "RetryPolicy":
        return self._policy

    @property
    def name(self) -> str:
        return self._policy.name

    def can_retry(self, exception: BaseException) -> bool:
        return self._policy.can_retry(exception)

    def next_attempt(self) -> Optional[int]:
        """
        Consume one attempt from this policy's budget.

        Returns
        -------
        Randomized time (in milliseconds) to wait until this attempt,
        or None if this policy doesn't allow more retries.
        """
        policy = self._policy
        if policy.max_retries is not None and self._attempt >= policy.max_retries:
            # Budget exhausted under this policy.
            return None

        self._attempt += 1
        wait_time = self._next_wait

        # Grow the backoff for the following attempt, capped at max_backoff.
        # NOTE(review): the wait only grows when max_backoff is set; an
        # uncapped policy keeps a constant wait — presumably intentional,
        # mirroring the Scala-side client; confirm before changing.
        if policy.max_backoff is not None:
            self._next_wait = min(
                float(policy.max_backoff), wait_time * policy.backoff_multiplier
            )

        # Jitter only the current wait, after the future backoff was computed.
        if wait_time >= policy.min_jitter_threshold:
            wait_time += random.uniform(0, policy.jitter)

        # Truncate to a whole number of milliseconds.
        return int(wait_time)
class AttemptManager:
    """
    Context manager wrapping one retry attempt.

    Any exception raised in the ``with`` body is offered to the owning
    ``Retrying`` instance: accepted (retriable) exceptions are swallowed so
    the retry loop can continue; everything else propagates to the caller.
    """

    def __init__(self, retrying: "Retrying") -> None:
        self._retrying = retrying

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exception: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        if not isinstance(exception, BaseException):
            # Clean exit from the body: record success so the loop stops.
            self._retrying.accept_succeeded()
            return None
        # True suppresses the exception (it will be retried); False re-raises.
        return self._retrying.accept_exception(exception)
class Retrying:
    """
    Entry point into the retry logic.

    Applies an ordered collection of retry policies to a block of code;
    the first policy accepting an exception decides the backoff. Usage::

        for attempt in Retrying(...):
            with attempt:
                <code that may raise>

    Retriable errors cause a sleep and another attempt; once every matching
    policy has exhausted its budget, ``RetriesExceeded`` is raised from the
    last error. Non-retriable exceptions pass through transparently.
    """

    def __init__(
        self,
        policies: typing.Union[RetryPolicy, typing.Iterable[RetryPolicy]],
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        if isinstance(policies, RetryPolicy):
            policies = [policies]
        # Each policy gets its own mutable state (attempt counter, backoff).
        self._policies: List[RetryPolicyState] = [p.to_state() for p in policies]
        self._sleep = sleep
        self._exception: Optional[BaseException] = None
        self._done = False

    def can_retry(self, exception: BaseException) -> bool:
        # RetryException is retriable by construction, regardless of policies.
        if isinstance(exception, RetryException):
            return True
        return any(state.can_retry(exception) for state in self._policies)

    def accept_exception(self, exception: BaseException) -> bool:
        # Remember the error only when some policy (or RetryException) matches.
        if not self.can_retry(exception):
            return False
        self._exception = exception
        return True

    def accept_succeeded(self) -> None:
        self._done = True

    def _last_exception(self) -> BaseException:
        # _wait must only run after an exception was accepted.
        if self._exception is None:
            raise PySparkRuntimeError(
                errorClass="NO_ACTIVE_EXCEPTION",
                messageParameters={},
            )
        return self._exception

    def _wait(self) -> None:
        exception = self._last_exception()
        if isinstance(exception, RetryException):
            # Immediately retriable: no policy consulted, no backoff.
            logger.debug(f"Got error: {repr(exception)}. Retrying.")
            return

        # The first policy that matches the error and still has budget
        # decides how long to sleep.
        for state in self._policies:
            if not state.can_retry(exception):
                continue
            wait_time = state.next_attempt()
            if wait_time is None:
                continue
            logger.debug(
                f"Got error: {repr(exception)}. "
                + f"Will retry after {wait_time} ms (policy: {state.name})"
            )
            self._sleep(wait_time / 1000)
            return

        # Every matching policy ran out of retries.
        logger.debug(f"Given up on retrying. error: {repr(exception)}")
        raise RetriesExceeded(errorClass="RETRIES_EXCEEDED", messageParameters={}) from exception

    def __iter__(self) -> Generator[AttemptManager, None, None]:
        """
        Generator wrapping the exception-producing code block.

        Returns
        -------
        A generator that yields one ``AttemptManager`` per attempt.
        """
        # The first attempt is free — no waiting before it.
        yield AttemptManager(self)
        while not self._done:
            self._wait()
            yield AttemptManager(self)
class RetryException(Exception):
    """
    Marker exception: raising it inside a retried block always triggers a
    retry, even when no configured policy matches it.
    """
class DefaultPolicy(RetryPolicy):
    # Please synchronize changes here with Scala side in
    # org.apache.spark.sql.connect.client.RetryPolicy
    #
    # Note: the number of retries is selected so that the maximum tolerated wait
    # is guaranteed to be at least 10 minutes

    def __init__(
        self,
        max_retries: Optional[int] = 15,
        backoff_multiplier: float = 4.0,
        initial_backoff: int = 50,
        max_backoff: Optional[int] = 60000,
        jitter: int = 500,
        min_jitter_threshold: int = 2000,
    ):
        super().__init__(
            max_retries=max_retries,
            backoff_multiplier=backoff_multiplier,
            initial_backoff=initial_backoff,
            max_backoff=max_backoff,
            jitter=jitter,
            min_jitter_threshold=min_jitter_threshold,
        )

    def can_retry(self, e: BaseException) -> bool:
        """
        Helper function that is used to identify if an exception thrown by the server
        can be retried or not.

        Parameters
        ----------
        e : Exception
            The GRPC error as received from the server. Typed as Exception, because other exception
            thrown during client processing can be passed here as well.

        Returns
        -------
        True if the exception can be retried, False otherwise.
        """
        # Anything that is not a gRPC error is out of scope for this policy.
        if not isinstance(e, grpc.RpcError):
            return False

        status = e.code()
        # INTERNAL is retried only when another RPC preempted this RPC
        # (the server reports a disconnected cursor in that case).
        if status == grpc.StatusCode.INTERNAL and "INVALID_CURSOR.DISCONNECTED" in str(e):
            return True

        # Plain service unavailability is considered transient and retriable.
        return status == grpc.StatusCode.UNAVAILABLE