blob: 9e41edb9ea8465f07e20455e077046226c6ca50f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import hashlib
import hmac
import logging
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from datetime import datetime, timezone
from typing import Dict, Optional
from pypaimon.api.token_loader import (DLFToken, DLFTokenLoader,
DLFTokenLoaderFactory)
from pypaimon.api.typedef import RESTAuthParameter
from pypaimon.common.options import Options
from pypaimon.common.options.config import CatalogOptions
class AuthProvider(ABC):
@abstractmethod
def merge_auth_header(
self, base_header: Dict[str, str], parameter: RESTAuthParameter
) -> Dict[str, str]:
"""Merge authorization header into header."""
class RESTAuthFunction:
def __init__(self,
init_header: Dict[str, str],
auth_provider: AuthProvider):
self.init_header = init_header.copy() if init_header else {}
self.auth_provider = auth_provider
def __call__(
self, rest_auth_parameter: RESTAuthParameter) -> Dict[str, str]:
return self.auth_provider.merge_auth_header(
self.init_header, rest_auth_parameter
)
def apply(self, rest_auth_parameter: RESTAuthParameter) -> Dict[str, str]:
return self.__call__(rest_auth_parameter)
class AuthProviderFactory:
@staticmethod
def create_auth_provider(options: Options) -> AuthProvider:
provider = options.get(CatalogOptions.TOKEN_PROVIDER)
if provider == 'bear':
token = options.get(CatalogOptions.TOKEN)
return BearTokenAuthProvider(token)
elif provider == 'dlf':
return DLFAuthProvider(
options.get(CatalogOptions.DLF_REGION),
DLFToken.from_options(options),
DLFTokenLoaderFactory.create_token_loader(options)
)
raise ValueError('Unknown auth provider')
class BearTokenAuthProvider(AuthProvider):
def __init__(self, token: str):
self.token = token
def merge_auth_header(
self, base_header: Dict[str, str], rest_auth_parameter: RESTAuthParameter
) -> Dict[str, str]:
headers_with_auth = base_header.copy()
headers_with_auth['Authorization'] = f'Bearer {self.token}'
return headers_with_auth
class DLFAuthProvider(AuthProvider):
DLF_AUTHORIZATION_HEADER_KEY = "Authorization"
DLF_CONTENT_MD5_HEADER_KEY = "Content-MD5"
DLF_CONTENT_TYPE_KEY = "Content-Type"
DLF_DATE_HEADER_KEY = "x-dlf-date"
DLF_SECURITY_TOKEN_HEADER_KEY = "x-dlf-security-token"
DLF_AUTH_VERSION_HEADER_KEY = "x-dlf-version"
DLF_CONTENT_SHA56_HEADER_KEY = "x-dlf-content-sha256"
DLF_CONTENT_SHA56_VALUE = "UNSIGNED-PAYLOAD"
AUTH_DATE_TIME_FORMAT = "%Y%m%dT%H%M%SZ"
MEDIA_TYPE = "application/json"
TOKEN_EXPIRATION_SAFE_TIME_MILLIS = 3_600_000
token: DLFToken
token_loader: DLFTokenLoader
region: str
def __init__(self,
region: str,
token: DLFToken = None,
token_loader: DLFTokenLoader = None):
self.logger = logging.getLogger(self.__class__.__name__)
if token is None and token_loader is None:
raise ValueError("Either token or token_loader must be provided")
self.token = token
self.token_loader = token_loader
self.region = region
def get_token(self):
if self.token_loader is not None:
if self.token is None:
self.token = self.token_loader.load_token()
elif self.token is not None and self.token.expiration_at_millis is not None:
if self.token.expiration_at_millis - int(time.time() * 1000) < self.TOKEN_EXPIRATION_SAFE_TIME_MILLIS:
self.token = self.token_loader.load_token()
if self.token is None:
raise ValueError("Either token or token_loader must be provided")
return self.token
def merge_auth_header(
self, base_header: Dict[str, str], rest_auth_parameter: RESTAuthParameter
) -> Dict[str, str]:
try:
date_time = base_header.get(
self.DLF_DATE_HEADER_KEY.lower(), datetime.now(
timezone.utc).strftime(
self.AUTH_DATE_TIME_FORMAT), )
date = date_time[:8]
sign_headers = self.generate_sign_headers(
rest_auth_parameter.data,
date_time,
self.get_token().security_token
)
authorization = DLFAuthSignature.get_authorization(
rest_auth_parameter=rest_auth_parameter,
dlf_token=self.get_token(),
region=self.region,
headers=sign_headers,
date_time=date_time,
date=date,
)
headers_with_auth = base_header.copy()
headers_with_auth.update(sign_headers)
headers_with_auth[self.DLF_AUTHORIZATION_HEADER_KEY] = authorization
return headers_with_auth
except Exception as e:
raise RuntimeError(f"Failed to merge auth header: {e}")
@classmethod
def generate_sign_headers(
cls, data: Optional[str], date_time: str, security_token: Optional[str]
) -> Dict[str, str]:
sign_headers = {}
sign_headers[cls.DLF_DATE_HEADER_KEY] = date_time
sign_headers[cls.DLF_CONTENT_SHA56_HEADER_KEY] = cls.DLF_CONTENT_SHA56_VALUE
# DLFAuthSignature.VERSION
sign_headers[cls.DLF_AUTH_VERSION_HEADER_KEY] = "v1"
if data is not None and data != "":
sign_headers[cls.DLF_CONTENT_TYPE_KEY] = cls.MEDIA_TYPE
sign_headers[cls.DLF_CONTENT_MD5_HEADER_KEY] = DLFAuthSignature.md5(
data)
if security_token is not None:
sign_headers[cls.DLF_SECURITY_TOKEN_HEADER_KEY] = security_token
return sign_headers
class DLFAuthSignature:
VERSION = "v1"
SIGNATURE_ALGORITHM = "DLF4-HMAC-SHA256"
PRODUCT = "DlfNext"
HMAC_SHA256 = "sha256"
REQUEST_TYPE = "aliyun_v4_request"
SIGNATURE_KEY = "Signature"
NEW_LINE = "\n"
SIGNED_HEADERS = [
DLFAuthProvider.DLF_CONTENT_MD5_HEADER_KEY.lower(),
DLFAuthProvider.DLF_CONTENT_TYPE_KEY.lower(),
DLFAuthProvider.DLF_CONTENT_SHA56_HEADER_KEY.lower(),
DLFAuthProvider.DLF_DATE_HEADER_KEY.lower(),
DLFAuthProvider.DLF_AUTH_VERSION_HEADER_KEY.lower(),
DLFAuthProvider.DLF_SECURITY_TOKEN_HEADER_KEY.lower(),
]
@classmethod
def get_authorization(
cls,
rest_auth_parameter: RESTAuthParameter,
dlf_token: DLFToken,
region: str,
headers: Dict[str, str],
date_time: str,
date: str,
) -> str:
try:
canonical_request = cls.get_canonical_request(
rest_auth_parameter, headers)
string_to_sign = cls.NEW_LINE.join(
[
cls.SIGNATURE_ALGORITHM,
date_time,
f"{date}/{region}/{cls.PRODUCT}/{cls.REQUEST_TYPE}",
cls._sha256_hex(canonical_request),
]
)
date_key = cls._hmac_sha256(
f"aliyun_v4{dlf_token.access_key_secret}".encode("utf-8"), date
)
date_region_key = cls._hmac_sha256(date_key, region)
date_region_service_key = cls._hmac_sha256(
date_region_key, cls.PRODUCT)
signing_key = cls._hmac_sha256(
date_region_service_key, cls.REQUEST_TYPE)
result = cls._hmac_sha256(signing_key, string_to_sign)
signature = cls._hex_encode(result)
credential = (f"{cls.SIGNATURE_ALGORITHM} "
f"Credential={dlf_token.access_key_id}/{date}/{region}/{cls.PRODUCT}/{cls.REQUEST_TYPE}")
signature_part = f"{cls.SIGNATURE_KEY}={signature}"
return f"{credential},{signature_part}"
except Exception as e:
raise RuntimeError(f"Failed to generate authorization: {e}")
@classmethod
def md5(cls, raw: str) -> str:
try:
md5_hash = hashlib.md5(raw.encode("utf-8")).digest()
return base64.b64encode(md5_hash).decode("utf-8")
except Exception as e:
raise RuntimeError(f"Failed to calculate MD5: {e}")
@classmethod
def _hmac_sha256(cls, key: bytes, data: str) -> bytes:
try:
return hmac.new(key, data.encode("utf-8"), hashlib.sha256).digest()
except Exception as e:
raise RuntimeError(f"Failed to calculate HMAC-SHA256: {e}")
@classmethod
def get_canonical_request(
cls, rest_auth_parameter: RESTAuthParameter, headers: Dict[str, str]
) -> str:
canonical_request = cls.NEW_LINE.join(
[rest_auth_parameter.method, rest_auth_parameter.path]
)
canonical_query_string = cls._build_canonical_query_string(
rest_auth_parameter.parameters
)
canonical_request = cls.NEW_LINE.join(
[canonical_request, canonical_query_string]
)
sorted_signed_headers_map = cls._build_sorted_signed_headers_map(
headers)
for key, value in sorted_signed_headers_map.items():
canonical_request = cls.NEW_LINE.join(
[canonical_request, f"{key}:{value}"])
content_sha256 = headers.get(
DLFAuthProvider.DLF_CONTENT_SHA56_HEADER_KEY,
DLFAuthProvider.DLF_CONTENT_SHA56_VALUE,
)
return cls.NEW_LINE.join([canonical_request, content_sha256])
@classmethod
def _build_canonical_query_string(
cls, parameters: Optional[Dict[str, str]]) -> str:
if not parameters:
return ""
sorted_params = OrderedDict(sorted(parameters.items()))
query_parts = []
for key, value in sorted_params.items():
key = cls._trim(key)
if value is not None and value != "":
value = cls._trim(value)
query_parts.append(f"{key}={value}")
else:
query_parts.append(key)
return "&".join(query_parts)
@classmethod
def _build_sorted_signed_headers_map(
cls, headers: Optional[Dict[str, str]]
) -> OrderedDict:
sorted_headers = OrderedDict()
if headers:
for key, value in headers.items():
lower_key = key.lower()
if lower_key in cls.SIGNED_HEADERS:
sorted_headers[lower_key] = cls._trim(value)
return OrderedDict(sorted(sorted_headers.items()))
@classmethod
def _sha256_hex(cls, raw: str) -> str:
try:
sha256_hash = hashlib.sha256(raw.encode("utf-8")).digest()
return cls._hex_encode(sha256_hash)
except Exception as e:
raise RuntimeError(f"Failed to calculate SHA256: {e}")
@classmethod
def _hex_encode(cls, raw: bytes) -> str:
if raw is None:
return None
return raw.hex()
@classmethod
def _trim(cls, value: str) -> str:
return value.strip() if value else ""