blob: 104365579c2fcf190fb1dcdd2c978db9dd4745b2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import re
import time
from datetime import datetime, timezone
from typing import Dict
from pypaimon.api.auth.base import AuthProvider
from pypaimon.api.auth.dlf_signer import (
DLFDefaultSigner,
DLFOpenApiSigner,
DLFRequestSigner,
)
from pypaimon.api.token_loader import DLFToken, DLFTokenLoader
from pypaimon.api.typedef import RESTAuthParameter
class DLFAuthProvider(AuthProvider):
    """Auth provider that signs REST requests with a DLF token.

    The provider is configured with a fixed ``DLFToken``, a ``DLFTokenLoader``,
    or both. When a loader is present, the token is loaded lazily and reloaded
    when it gets close to expiry. Each call to ``merge_auth_header`` signs the
    request with the signer chosen by ``signing_algorithm``. It returns a copy
    of the base headers extended with the signing headers and an
    ``Authorization`` header.
    """

    DLF_AUTHORIZATION_HEADER_KEY = "Authorization"
    # Reload a loader-managed token once it expires within this window
    # (3_600_000 ms == 1 hour).
    TOKEN_EXPIRATION_SAFE_TIME_MILLIS = 3_600_000
    # Leading http/https scheme to strip from the URI. Schemes are
    # case-insensitive, so "HTTPS://" must be stripped as well.
    _SCHEME_PATTERN = re.compile(r'^https?://', re.IGNORECASE)
    # Characters that terminate the authority (host[:port]) part of a URI.
    _AUTHORITY_TERMINATORS = re.compile(r'[/?#]')

    def __init__(self,
                 uri: str,
                 region: str,
                 signing_algorithm: str,
                 token: DLFToken = None,
                 token_loader: DLFTokenLoader = None):
        """Create the provider.

        Args:
            uri: Endpoint URI; its host part is used when signing.
            region: Region passed to the default signer.
            signing_algorithm: Signer identifier. ``DLFOpenApiSigner.IDENTIFIER``
                selects the OpenAPI signer; any other value selects the
                default signer.
            token: Optional fixed token.
            token_loader: Optional loader used to obtain or refresh the token.

        Raises:
            ValueError: If neither ``token`` nor ``token_loader`` is given.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        if token is None and token_loader is None:
            raise ValueError("Either token or token_loader must be provided")
        self.token = token
        self.token_loader = token_loader
        self.uri = uri
        self.region = region
        self.signing_algorithm = signing_algorithm
        self.signer = self._create_signer(signing_algorithm)

    def _create_signer(self, signing_algorithm: str) -> DLFRequestSigner:
        """Return the request signer matching ``signing_algorithm``."""
        if signing_algorithm == DLFOpenApiSigner.IDENTIFIER:
            return DLFOpenApiSigner()
        # Any other identifier falls back to the region-aware default signer.
        return DLFDefaultSigner(self.region)

    @staticmethod
    def extract_host(uri: str) -> str:
        """Return the authority (``host[:port]``) part of ``uri``.

        A leading ``http://`` or ``https://`` (any case) is removed. The result
        is then cut at the first ``/``, ``?`` or ``#``. A URI without a scheme
        is treated as starting with the host.
        """
        without_scheme = DLFAuthProvider._SCHEME_PATTERN.sub('', uri, count=1)
        return DLFAuthProvider._AUTHORITY_TERMINATORS.split(without_scheme, maxsplit=1)[0]

    def get_token(self) -> DLFToken:
        """Return the current token, loading or refreshing it if needed.

        When a token loader is configured, the token is (re)loaded if none is
        cached yet, or if the cached token expires within
        ``TOKEN_EXPIRATION_SAFE_TIME_MILLIS``.

        Raises:
            ValueError: If no token is available after the (re)load attempt.
        """
        if self.token_loader is not None and (
                self.token is None or self._is_expiring_soon(self.token)):
            self.token = self.token_loader.load_token()
        if self.token is None:
            raise ValueError("Either token or token_loader must be provided")
        return self.token

    def _is_expiring_soon(self, token: DLFToken) -> bool:
        """Return True if ``token`` expires within the safe-time window.

        Tokens without an expiration timestamp are never considered expiring.
        """
        expiration_millis = token.expiration_at_millis
        if expiration_millis is None:
            return False
        now_millis = int(time.time() * 1000)
        return expiration_millis - now_millis < self.TOKEN_EXPIRATION_SAFE_TIME_MILLIS

    def merge_auth_header(
        self, base_header: Dict[str, str], rest_auth_parameter: RESTAuthParameter
    ) -> Dict[str, str]:
        """Return ``base_header`` plus the signing and Authorization headers.

        ``base_header`` itself is not modified; a copy is extended and returned.

        Raises:
            RuntimeError: If obtaining the token or signing fails. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            token = self.get_token()
            now = datetime.now(timezone.utc)
            host = self.extract_host(self.uri)
            sign_headers = self.signer.sign_headers(
                rest_auth_parameter.data,
                now,
                token.security_token,
                host
            )
            authorization = self.signer.authorization(
                rest_auth_parameter,
                token,
                host,
                sign_headers
            )
            headers_with_auth = base_header.copy()
            headers_with_auth.update(sign_headers)
            headers_with_auth[self.DLF_AUTHORIZATION_HEADER_KEY] = authorization
            return headers_with_auth
        except Exception as e:
            # Keep the original exception as the cause so its traceback is
            # not lost behind the wrapper.
            raise RuntimeError(f"Failed to merge auth header: {e}") from e