################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import os
import logging
import subprocess
from pathlib import Path
from typing import Optional, List, Dict, Any
from urllib.parse import urlparse

import pyarrow.fs
import pyarrow as pa
from pyarrow.fs import FileSystem

from pypaimon.pynative.common.exception import PyNativeNotImplementedError
from pypaimon.pynative.common.core_option import CoreOptions

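# Keys read from `catalog_options`. The `s3.*` keys map directly onto pyarrow
# S3FileSystem options; the `aws.*` keys below are accepted as fallbacks for
# role-based authentication when the `s3.*` variants are absent.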
S3_ENDPOINT = "s3.endpoint"
S3_ACCESS_KEY_ID = "s3.access-key"
S3_SECRET_ACCESS_KEY = "s3.secret-key"
S3_SESSION_TOKEN = "s3.session.token"
S3_REGION = "s3.region"
S3_PROXY_URI = "s3.proxy.uri"
S3_CONNECT_TIMEOUT = "s3.connect.timeout"
S3_REQUEST_TIMEOUT = "s3.request.timeout"
S3_ROLE_ARN = "s3.role.arn"
S3_ROLE_SESSION_NAME = "s3.role.session.name"
S3_FORCE_VIRTUAL_ADDRESSING = "s3.force.virtual.addressing"

AWS_ROLE_ARN = "aws.role.arn"
AWS_ROLE_SESSION_NAME = "aws.role.session.name"


class FileIO:
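    """pyarrow-backed file I/O, dispatched on the warehouse URI scheme
    (oss, s3/s3a/s3n, hdfs/viewfs, or plain local paths).

    Illustrative usage (bucket name and credentials are placeholders):

        fio = FileIO("s3://my-bucket/warehouse", {
            "s3.access-key": "...",
            "s3.secret-key": "...",
        })
        fio.check_or_mkdirs(Path("my-bucket/warehouse/default.db"))
    """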
    def __init__(self, warehouse: str, catalog_options: dict):
        self.properties = catalog_options
        self.logger = logging.getLogger(__name__)
        scheme, netloc, _ = self.parse_location(warehouse)
        if scheme == "oss":
            self.filesystem = self._initialize_oss_fs()
        elif scheme in {"s3", "s3a", "s3n"}:
            self.filesystem = self._initialize_s3_fs()
        elif scheme in {"hdfs", "viewfs"}:
            self.filesystem = self._initialize_hdfs_fs(scheme, netloc)
        elif scheme == "file":
            self.filesystem = self._initialize_local_fs()
        else:
            raise ValueError(f"Unrecognized filesystem type in URI: {scheme}")

    @staticmethod
    def parse_location(location: str):
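        """Split a location into (scheme, netloc, path).

        Examples (following urlparse semantics; local paths are made absolute):
            "s3://bucket/wh"    -> ("s3", "bucket", "bucket/wh")
            "hdfs://nn:8020/wh" -> ("hdfs", "nn:8020", "/wh")
            "/tmp/wh"           -> ("file", "", "/tmp/wh")
        """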
        uri = urlparse(location)
        if not uri.scheme:
            return "file", uri.netloc, os.path.abspath(location)
        elif uri.scheme in ("hdfs", "viewfs"):
            return uri.scheme, uri.netloc, uri.path
        else:
            return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"

    def _initialize_oss_fs(self) -> FileSystem:
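        """Build an S3FileSystem against Aliyun OSS's S3-compatible endpoint.

        OSS generally requires virtual-hosted-style addressing, so
        force_virtual_addressing defaults to True here (overridable via
        "s3.force.virtual.addressing").
        """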
        from pyarrow.fs import S3FileSystem

        client_kwargs = {
            "endpoint_override": self.properties.get(S3_ENDPOINT),
            "access_key": self.properties.get(S3_ACCESS_KEY_ID),
            "secret_key": self.properties.get(S3_SECRET_ACCESS_KEY),
            "session_token": self.properties.get(S3_SESSION_TOKEN),
            "region": self.properties.get(S3_REGION),
            "force_virtual_addressing": self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING, True),
        }

        if proxy_uri := self.properties.get(S3_PROXY_URI):
            client_kwargs["proxy_options"] = proxy_uri
        if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT):
            client_kwargs["connect_timeout"] = float(connect_timeout)
        if request_timeout := self.properties.get(S3_REQUEST_TIMEOUT):
            client_kwargs["request_timeout"] = float(request_timeout)
        if role_arn := self.properties.get(S3_ROLE_ARN):
            client_kwargs["role_arn"] = role_arn
        if session_name := self.properties.get(S3_ROLE_SESSION_NAME):
            client_kwargs["session_name"] = session_name

        return S3FileSystem(**client_kwargs)

    def _initialize_s3_fs(self) -> FileSystem:
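        """Build an S3FileSystem for AWS S3 or any S3-compatible store.

        Unlike OSS, path-style addressing is the default here
        ("s3.force.virtual.addressing" flips it); role-based auth falls back
        to the "aws.role.arn" / "aws.role.session.name" keys.
        """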
        from pyarrow.fs import S3FileSystem

        client_kwargs = {
            "endpoint_override": self.properties.get(S3_ENDPOINT),
            "access_key": self.properties.get(S3_ACCESS_KEY_ID),
            "secret_key": self.properties.get(S3_SECRET_ACCESS_KEY),
            "session_token": self.properties.get(S3_SESSION_TOKEN),
            "region": self.properties.get(S3_REGION),
        }

        if proxy_uri := self.properties.get(S3_PROXY_URI):
            client_kwargs["proxy_options"] = proxy_uri
        if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT):
            client_kwargs["connect_timeout"] = float(connect_timeout)
        if request_timeout := self.properties.get(S3_REQUEST_TIMEOUT):
            client_kwargs["request_timeout"] = float(request_timeout)
        # fall back to the "aws.*" keys when the "s3.*" variants are absent
        if role_arn := (self.properties.get(S3_ROLE_ARN) or self.properties.get(AWS_ROLE_ARN)):
            client_kwargs["role_arn"] = role_arn
        if session_name := (self.properties.get(S3_ROLE_SESSION_NAME)
                            or self.properties.get(AWS_ROLE_SESSION_NAME)):
            client_kwargs["session_name"] = session_name

        client_kwargs["force_virtual_addressing"] = self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING, False)
        return S3FileSystem(**client_kwargs)

    def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem:
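        """Connect to HDFS through pyarrow's libhdfs (JNI) bindings.

        Requires HADOOP_HOME and HADOOP_CONF_DIR; CLASSPATH is populated from
        `hadoop classpath --glob` so libhdfs can locate the client jars.
        """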
        from pyarrow.fs import HadoopFileSystem

        if 'HADOOP_HOME' not in os.environ:
            raise RuntimeError("HADOOP_HOME environment variable is not set.")
        if 'HADOOP_CONF_DIR' not in os.environ:
            raise RuntimeError("HADOOP_CONF_DIR environment variable is not set.")

        hadoop_home = os.environ["HADOOP_HOME"]
        native_lib_path = f"{hadoop_home}/lib/native"
        os.environ['LD_LIBRARY_PATH'] = f"{native_lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"

        # Populate CLASSPATH so libhdfs can locate the Hadoop client jars.
        class_paths = subprocess.run(
            [f'{hadoop_home}/bin/hadoop', 'classpath', '--glob'],
            capture_output=True,
            text=True,
            check=True
        )
        os.environ['CLASSPATH'] = class_paths.stdout.strip()

        # netloc looks like "host[:port]"; assume the NameNode default port
        # (8020) when none is given, and "default" (fs.defaultFS) when the
        # host is missing.
        host, _, port_str = (netloc or "").partition(":")
        return HadoopFileSystem(
            host=host or "default",
            port=int(port_str) if port_str else 8020,
            user=os.environ.get('HADOOP_USER_NAME', 'hadoop')
        )

    def _initialize_local_fs(self) -> FileSystem:
        from pyarrow.fs import LocalFileSystem

        return LocalFileSystem()

    def new_input_stream(self, path: Path):
        return self.filesystem.open_input_file(str(path))

    def new_output_stream(self, path: Path):
        # create missing parent directories before opening the stream
        parent_dir = path.parent
        if str(parent_dir) and not self.exists(parent_dir):
            self.mkdirs(parent_dir)

        return self.filesystem.open_output_stream(str(path))

    def get_file_status(self, path: Path):
        file_infos = self.filesystem.get_file_info([str(path)])
        return file_infos[0]

    def list_status(self, path: Path):
        selector = pyarrow.fs.FileSelector(str(path), recursive=False, allow_not_found=True)
        return self.filesystem.get_file_info(selector)

    def list_directories(self, path: Path):
        file_infos = self.list_status(path)
        return [info for info in file_infos if info.type == pyarrow.fs.FileType.Directory]

    def exists(self, path: Path) -> bool:
        try:
            file_info = self.filesystem.get_file_info([str(path)])[0]
            return file_info.type != pyarrow.fs.FileType.NotFound
        except Exception:
            return False

    def delete(self, path: Path, recursive: bool = False) -> bool:
        try:
            file_info = self.filesystem.get_file_info([str(path)])[0]
            if file_info.type == pyarrow.fs.FileType.Directory:
                if recursive:
                    # delete_dir removes the directory together with its contents
                    self.filesystem.delete_dir(str(path))
                else:
                    # a non-recursive delete must not silently drop directory contents
                    if self.list_status(path):
                        raise OSError(f"Directory {path} is not empty")
                    self.filesystem.delete_dir(str(path))
            else:
                self.filesystem.delete_file(str(path))
            return True
        except Exception as e:
            self.logger.warning(f"Failed to delete {path}: {e}")
            return False


    def mkdirs(self, path: Path) -> bool:
        try:
            self.filesystem.create_dir(str(path), recursive=True)
            return True
        except Exception as e:
            self.logger.warning(f"Failed to create directory {path}: {e}")
            return False

    def rename(self, src: Path, dst: Path) -> bool:
        try:
            dst_parent = dst.parent
            if str(dst_parent) and not self.exists(dst_parent):
                self.mkdirs(dst_parent)

            self.filesystem.move(str(src), str(dst))
            return True
        except Exception as e:
            self.logger.warning(f"Failed to rename {src} to {dst}: {e}")
            return False

    def delete_quietly(self, path: Path):
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f"Ready to delete {path}")

        try:
            if not self.delete(path, False) and self.exists(path):
                self.logger.warning(f"Failed to delete file {path}")
        except Exception:
            self.logger.warning(f"Exception occurred while deleting file {path}", exc_info=True)

    def delete_files_quietly(self, files: List[Path]):
        for file_path in files:
            self.delete_quietly(file_path)

    def delete_directory_quietly(self, directory: Path):
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f"Ready to delete {directory}")

        try:
            if not self.delete(directory, True) and self.exists(directory):
                self.logger.warning(f"Failed to delete directory {directory}")
        except Exception:
            self.logger.warning(f"Exception occurred while deleting directory {directory}", exc_info=True)

    def get_file_size(self, path: Path) -> int:
        file_info = self.get_file_status(path)
        if file_info.size is None:
            raise ValueError(f"File size not available for {path}")
        return file_info.size

    def is_dir(self, path: Path) -> bool:
        file_info = self.get_file_status(path)
        return file_info.type == pyarrow.fs.FileType.Directory

    def check_or_mkdirs(self, path: Path):
        if self.exists(path):
            if not self.is_dir(path):
                raise ValueError(f"The path '{path}' should be a directory.")
        else:
            self.mkdirs(path)

    def read_file_utf8(self, path: Path) -> str:
        with self.new_input_stream(path) as input_stream:
            return input_stream.read().decode('utf-8')

    def try_to_write_atomic(self, path: Path, content: str) -> bool:
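        """Write content via a ``.tmp`` sibling, then rename into place.

        This is only as atomic as the underlying rename: object stores such
        as S3/OSS implement rename as copy+delete, so treat it as
        best-effort there.
        """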
        temp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else Path(str(path) + ".tmp")
        success = False
        try:
            self.write_file(temp_path, content, False)
            success = self.rename(temp_path, path)
        finally:
            if not success:
                self.delete_quietly(temp_path)
        return success

    def write_file(self, path: Path, content: str, overwrite: bool = False):
        # honor the overwrite flag instead of silently ignoring it
        if not overwrite and self.exists(path):
            raise FileExistsError(f"File {path} already exists and overwrite=False")
        with self.new_output_stream(path) as output_stream:
            output_stream.write(content.encode('utf-8'))

    def overwrite_file_utf8(self, path: Path, content: str):
        with self.new_output_stream(path) as output_stream:
            output_stream.write(content.encode('utf-8'))

    def copy_file(self, source_path: Path, target_path: Path, overwrite: bool = False):
        if not overwrite and self.exists(target_path):
            raise FileExistsError(f"Target file {target_path} already exists and overwrite=False")

        self.filesystem.copy_file(str(source_path), str(target_path))

    def copy_files(self, source_directory: Path, target_directory: Path, overwrite: bool = False):
        file_infos = self.list_status(source_directory)
        for file_info in file_infos:
            if file_info.type == pyarrow.fs.FileType.File:
                source_file = Path(file_info.path)
                target_file = target_directory / source_file.name
                self.copy_file(source_file, target_file, overwrite)

    def read_overwritten_file_utf8(self, path: Path) -> Optional[str]:
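        """Read a file that may be overwritten concurrently.

        Retries up to five times on errors that look like a concurrent
        overwrite (e.g. an S3A ``RemoteFileChangedException``); returns None
        once the file no longer exists.
        """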
        retry_number = 0
        exception = None
        while retry_number < 5:
            try:
                return self.read_file_utf8(path)
            except FileNotFoundError:
                return None
            except Exception as e:
                if not self.exists(path):
                    return None

                if (type(e).__name__.endswith("RemoteFileChangedException")
                        or ("Blocklist for" in str(e) and "has changed" in str(e))):
                    exception = e
                    retry_number += 1
                else:
                    raise

        if exception is not None:
            raise exception

        return None

    def write_parquet(self, path: Path, data: pa.RecordBatch, compression: str = 'snappy', **kwargs):
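        """Write a single RecordBatch as a Parquet file.

        Extra kwargs are forwarded to pyarrow.parquet.ParquetWriter; a
        partially written file is deleted on failure.
        """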
        try:
            import pyarrow.parquet as pq

            with self.new_output_stream(path) as output_stream:
                with pq.ParquetWriter(output_stream, data.schema, compression=compression, **kwargs) as pw:
                    pw.write_batch(data)

        except Exception as e:
            self.delete_quietly(path)
            raise RuntimeError(f"Failed to write Parquet file {path}: {e}") from e

    def write_orc(self, path: Path, data: pa.RecordBatch, compression: str = 'zstd', **kwargs):
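        """Write a single RecordBatch as an ORC file.

        Extra kwargs are forwarded to pyarrow.orc.write_table; a partially
        written file is deleted on failure.
        """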
        try:
            import pyarrow.orc as orc

            # pyarrow.orc.write_table expects a Table, so wrap the batch first
            table = pa.Table.from_batches([data])
            with self.new_output_stream(path) as output_stream:
                orc.write_table(
                    table,
                    output_stream,
                    compression=compression,
                    **kwargs
                )

        except Exception as e:
            self.delete_quietly(path)
            raise RuntimeError(f"Failed to write ORC file {path}: {e}") from e

    def write_avro(self, path: Path, table: pa.RecordBatch, schema: Optional[Dict[str, Any]] = None, **kwargs):
        raise PyNativeNotImplementedError(CoreOptions.FILE_FORMAT_AVRO)
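

# Minimal local smoke test (illustrative; paths are placeholders):
#
#     fio = FileIO("/tmp/paimon-warehouse", {})
#     fio.check_or_mkdirs(Path("/tmp/paimon-warehouse/default.db"))
#     fio.write_file(Path("/tmp/paimon-warehouse/default.db/hello.txt"), "hi", overwrite=True)
#     assert fio.read_file_utf8(Path("/tmp/paimon-warehouse/default.db/hello.txt")) == "hi"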