################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import logging
import os
import subprocess
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pyarrow
import pyarrow.fs as pafs
from packaging.version import parse
from pyarrow._fs import FileSystem
from pypaimon.common.file_io import FileIO
from pypaimon.common.options import Options
from pypaimon.common.options.config import OssOptions, S3Options
from pypaimon.common.uri_reader import UriReaderFactory
from pypaimon.schema.data_types import DataField, AtomicType, PyarrowFieldParser
from pypaimon.table.row.blob import BlobData, BlobDescriptor, Blob
from pypaimon.table.row.generic_row import GenericRow
from pypaimon.table.row.row_kind import RowKind
from pypaimon.write.blob_format_writer import BlobFormatWriter
class PyArrowFileIO(FileIO):
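    """FileIO implementation backed by PyArrow filesystems (OSS, S3, HDFS).

    A minimal usage sketch (illustrative only: the option key and the
    Options construction are assumptions, not verified API):

        options = Options({"s3.endpoint": "http://localhost:9000"})
        file_io = PyArrowFileIO("s3://my-bucket/warehouse", options)
        with file_io.new_output_stream("s3://my-bucket/warehouse/f.txt") as out:
            out.write(b"hello")
    """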
def __init__(self, path: str, catalog_options: Options):
self.properties = catalog_options
self.logger = logging.getLogger(__name__)
self._pyarrow_gte_7 = parse(pyarrow.__version__) >= parse("7.0.0")
self._pyarrow_gte_8 = parse(pyarrow.__version__) >= parse("8.0.0")
scheme, netloc, _ = self.parse_location(path)
self.uri_reader_factory = UriReaderFactory(catalog_options)
        self._is_oss = scheme == "oss"
self._oss_bucket = None
if self._is_oss:
self._oss_bucket = self._extract_oss_bucket(path)
self.filesystem = self._initialize_oss_fs(path)
elif scheme in {"s3", "s3a", "s3n"}:
self.filesystem = self._initialize_s3_fs()
elif scheme in {"hdfs", "viewfs"}:
self.filesystem = self._initialize_hdfs_fs(scheme, netloc)
else:
raise ValueError(f"Unrecognized filesystem type in URI: {scheme}")
@staticmethod
def parse_location(location: str):
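        """Split a location into (scheme, netloc, path).

        Examples derived from the branches below:

            parse_location("oss://bucket/a/b")   -> ("oss", "bucket", "bucket/a/b")
            parse_location("hdfs://nn:8020/a/b") -> ("hdfs", "nn:8020", "/a/b")
            parse_location("/tmp/a/b")           -> ("file", "", "/tmp/a/b")
        """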
uri = urlparse(location)
if not uri.scheme:
return "file", uri.netloc, os.path.abspath(location)
elif uri.scheme in ("hdfs", "viewfs"):
return uri.scheme, uri.netloc, uri.path
else:
return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"
def _create_s3_retry_config(
self,
max_attempts: int = 10,
request_timeout: int = 60,
connect_timeout: int = 60
) -> Dict[str, Any]:
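        """Build extra S3FileSystem kwargs for timeouts and retries.

        The timeout kwargs and AwsStandardS3RetryStrategy are only applied on
        PyArrow >= 8.0.0; older versions receive no extra configuration.
        """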
if self._pyarrow_gte_8:
config = {
'request_timeout': request_timeout,
'connect_timeout': connect_timeout
}
            try:
                # AwsStandardS3RetryStrategy is missing from some older PyArrow
                # builds; accessing it then raises AttributeError (not
                # ImportError), in which case fall back to timeouts only.
                config['retry_strategy'] = pafs.AwsStandardS3RetryStrategy(
                    max_attempts=max_attempts)
            except AttributeError:
                pass
return config
else:
return {}
def _extract_oss_bucket(self, location) -> str:
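        """Extract the bucket name from an OSS location.

        Examples derived from the parsing below:

            oss://my-bucket/path                               -> my-bucket
            oss://my-bucket.oss-cn-hangzhou.aliyuncs.com/path  -> my-bucket
            oss://ak:sk@endpoint/my-bucket/path                -> my-bucket
        """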
uri = urlparse(location)
if uri.scheme and uri.scheme != "oss":
raise ValueError("Not an OSS URI: {}".format(location))
netloc = uri.netloc or ""
        # Credentials embedded in the authority mean the host is an endpoint,
        # so the bucket is the first path segment instead.
        if uri.username or uri.password or "@" in netloc:
first_segment = uri.path.lstrip("/").split("/", 1)[0]
if not first_segment:
raise ValueError("Invalid OSS URI without bucket: {}".format(location))
return first_segment
host = getattr(uri, "hostname", None) or netloc
if not host:
raise ValueError("Invalid OSS URI without host: {}".format(location))
bucket = host.split(".", 1)[0]
if not bucket:
raise ValueError("Invalid OSS URI without bucket: {}".format(location))
return bucket
def _initialize_oss_fs(self, path) -> FileSystem:
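        """Create an S3FileSystem pointed at OSS.

        PyArrow >= 7 supports force_virtual_addressing, so the endpoint is
        used as-is; PyArrow 6 emulates virtual-host addressing by prefixing
        the bucket onto the endpoint override.
        """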
client_kwargs = {
"access_key": self.properties.get(OssOptions.OSS_ACCESS_KEY_ID),
"secret_key": self.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET),
"session_token": self.properties.get(OssOptions.OSS_SECURITY_TOKEN),
"region": self.properties.get(OssOptions.OSS_REGION),
}
if self._pyarrow_gte_7:
client_kwargs['force_virtual_addressing'] = True
client_kwargs['endpoint_override'] = self.properties.get(OssOptions.OSS_ENDPOINT)
else:
client_kwargs['endpoint_override'] = (self._oss_bucket + "." +
self.properties.get(OssOptions.OSS_ENDPOINT))
retry_config = self._create_s3_retry_config()
client_kwargs.update(retry_config)
return pafs.S3FileSystem(**client_kwargs)
def _initialize_s3_fs(self) -> FileSystem:
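        """Create an S3FileSystem from the S3Options entries in the catalog options."""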
client_kwargs = {
"endpoint_override": self.properties.get(S3Options.S3_ENDPOINT),
"access_key": self.properties.get(S3Options.S3_ACCESS_KEY_ID),
"secret_key": self.properties.get(S3Options.S3_ACCESS_KEY_SECRET),
"session_token": self.properties.get(S3Options.S3_SECURITY_TOKEN),
"region": self.properties.get(S3Options.S3_REGION),
}
if self._pyarrow_gte_7:
client_kwargs["force_virtual_addressing"] = True
retry_config = self._create_s3_retry_config()
client_kwargs.update(retry_config)
return pafs.S3FileSystem(**client_kwargs)
def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem:
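        """Create a HadoopFileSystem for hdfs:// or viewfs:// locations.

        HADOOP_HOME and HADOOP_CONF_DIR must be set; the Hadoop classpath is
        resolved by shelling out to `hadoop classpath --glob` and exported as
        CLASSPATH so that libhdfs can load the required jars.
        """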
if 'HADOOP_HOME' not in os.environ:
raise RuntimeError("HADOOP_HOME environment variable is not set.")
if 'HADOOP_CONF_DIR' not in os.environ:
raise RuntimeError("HADOOP_CONF_DIR environment variable is not set.")
hadoop_home = os.environ.get("HADOOP_HOME")
native_lib_path = f"{hadoop_home}/lib/native"
os.environ['LD_LIBRARY_PATH'] = f"{native_lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"
class_paths = subprocess.run(
[f'{hadoop_home}/bin/hadoop', 'classpath', '--glob'],
capture_output=True,
text=True,
check=True
)
os.environ['CLASSPATH'] = class_paths.stdout.strip()
        # urllib's splitport() is deprecated; parse host/port via urlparse.
        parsed = urlparse(f"//{netloc or ''}")
        host = parsed.hostname or netloc
        return pafs.HadoopFileSystem(
            host=host,
            # Port 0 lets libhdfs fall back to the port in the Hadoop config
            # (e.g. HA nameservices carry no explicit port in the URI).
            port=parsed.port if parsed.port is not None else 0,
            user=os.environ.get('HADOOP_USER_NAME', 'hadoop')
        )
def new_input_stream(self, path: str):
path_str = self.to_filesystem_path(path)
return self.filesystem.open_input_file(path_str)
def new_output_stream(self, path: str):
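        """Open an output stream, creating missing parent directories first.

        On PyArrow 6.x with OSS the path is a bare object key, so the parent
        is computed by string splitting rather than with pathlib.
        """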
path_str = self.to_filesystem_path(path)
if self._is_oss and not self._pyarrow_gte_7:
# For PyArrow 6.x + OSS, path_str is already just the key part
if '/' in path_str:
parent_dir = '/'.join(path_str.split('/')[:-1])
else:
parent_dir = ''
if parent_dir and not self.exists(parent_dir):
self.mkdirs(parent_dir)
else:
parent_dir = Path(path_str).parent
if str(parent_dir) and not self.exists(str(parent_dir)):
self.mkdirs(str(parent_dir))
return self.filesystem.open_output_stream(path_str)
def get_file_status(self, path: str):
path_str = self.to_filesystem_path(path)
file_infos = self.filesystem.get_file_info([path_str])
file_info = file_infos[0]
if file_info.type == pafs.FileType.NotFound:
raise FileNotFoundError(f"File {path} (resolved as {path_str}) does not exist")
return file_info
def list_status(self, path: str):
path_str = self.to_filesystem_path(path)
selector = pafs.FileSelector(path_str, recursive=False, allow_not_found=True)
return self.filesystem.get_file_info(selector)
def list_directories(self, path: str):
file_infos = self.list_status(path)
return [info for info in file_infos if info.type == pafs.FileType.Directory]
def exists(self, path: str) -> bool:
path_str = self.to_filesystem_path(path)
file_info = self.filesystem.get_file_info([path_str])[0]
return file_info.type != pafs.FileType.NotFound
def delete(self, path: str, recursive: bool = False) -> bool:
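        """Delete a file or directory.

        Returns False when the path does not exist. Deleting a non-empty
        directory without recursive=True raises OSError.
        """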
path_str = self.to_filesystem_path(path)
file_info = self.filesystem.get_file_info([path_str])[0]
if file_info.type == pafs.FileType.NotFound:
return False
if file_info.type == pafs.FileType.Directory:
if not recursive:
selector = pafs.FileSelector(path_str, recursive=False, allow_not_found=True)
dir_contents = self.filesystem.get_file_info(selector)
if len(dir_contents) > 0:
raise OSError(f"Directory {path} is not empty")
if recursive:
self.filesystem.delete_dir_contents(path_str)
self.filesystem.delete_dir(path_str)
else:
self.filesystem.delete_dir(path_str)
else:
self.filesystem.delete_file(path_str)
return True
def mkdirs(self, path: str) -> bool:
path_str = self.to_filesystem_path(path)
file_info = self.filesystem.get_file_info([path_str])[0]
if file_info.type == pafs.FileType.Directory:
return True
elif file_info.type == pafs.FileType.File:
raise FileExistsError(f"Path exists but is not a directory: {path}")
self.filesystem.create_dir(path_str, recursive=True)
return True
def rename(self, src: str, dst: str) -> bool:
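        """Rename src to dst, following HadoopFileIO semantics.

        If dst is an existing directory, src is moved into it as
        dst/srcFileName; an existing file at the destination makes the
        rename return False instead of overwriting.
        """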
dst_str = self.to_filesystem_path(dst)
dst_parent = Path(dst_str).parent
if str(dst_parent) and not self.exists(str(dst_parent)):
self.mkdirs(str(dst_parent))
src_str = self.to_filesystem_path(src)
try:
if hasattr(self.filesystem, 'rename'):
return self.filesystem.rename(src_str, dst_str)
dst_file_info = self.filesystem.get_file_info([dst_str])[0]
if dst_file_info.type != pafs.FileType.NotFound:
if dst_file_info.type == pafs.FileType.File:
return False
# Make it compatible with HadoopFileIO: if dst is an existing directory,
# dst=dst/srcFileName
src_name = Path(src_str).name
dst_str = str(Path(dst_str) / src_name)
final_dst_info = self.filesystem.get_file_info([dst_str])[0]
if final_dst_info.type != pafs.FileType.NotFound:
return False
self.filesystem.move(src_str, dst_str)
return True
except FileNotFoundError:
return False
except (PermissionError, OSError):
return False
def delete_quietly(self, path: str):
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Ready to delete {path}")
try:
if not self.delete(path, False) and self.exists(path):
self.logger.warning(f"Failed to delete file {path}")
except Exception:
self.logger.warning(f"Exception occurs when deleting file {path}", exc_info=True)
def delete_files_quietly(self, files: List[str]):
for file_path in files:
self.delete_quietly(file_path)
def delete_directory_quietly(self, directory: str):
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Ready to delete {directory}")
try:
if not self.delete(directory, True) and self.exists(directory):
self.logger.warning(f"Failed to delete directory {directory}")
except Exception:
self.logger.warning(f"Exception occurs when deleting directory {directory}", exc_info=True)
def try_to_write_atomic(self, path: str, content: str) -> bool:
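        """Attempt an atomic write: write to a unique .tmp file, then rename.

        Returns False (and removes the temp file) when the rename fails,
        e.g. because the target path already exists.
        """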
if self.exists(path):
path_str = self.to_filesystem_path(path)
file_info = self.filesystem.get_file_info([path_str])[0]
if file_info.type == pafs.FileType.Directory:
return False
temp_path = path + str(uuid.uuid4()) + ".tmp"
success = False
try:
self.write_file(temp_path, content, False)
success = self.rename(temp_path, path)
finally:
if not success:
self.delete_quietly(temp_path)
return success
def copy_file(self, source_path: str, target_path: str, overwrite: bool = False):
if not overwrite and self.exists(target_path):
raise FileExistsError(f"Target file {target_path} already exists and overwrite=False")
source_str = self.to_filesystem_path(source_path)
target_str = self.to_filesystem_path(target_path)
target_parent = Path(target_str).parent
if str(target_parent) and not self.exists(str(target_parent)):
self.mkdirs(str(target_parent))
self.filesystem.copy_file(source_str, target_str)
def write_parquet(self, path: str, data: pyarrow.Table, compression: str = 'zstd',
zstd_level: int = 1, **kwargs):
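        """Write a pyarrow.Table as Parquet.

        For zstd the requested level is forwarded as compression_level; a
        partially written file is deleted if the write fails.
        """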
try:
import pyarrow.parquet as pq
with self.new_output_stream(path) as output_stream:
if compression.lower() == 'zstd':
kwargs['compression_level'] = zstd_level
pq.write_table(data, output_stream, compression=compression, **kwargs)
except Exception as e:
self.delete_quietly(path)
raise RuntimeError(f"Failed to write Parquet file {path}: {e}") from e
    def write_orc(self, path: str, data: pyarrow.Table, compression: str = 'zstd',
                  zstd_level: int = 1, **kwargs):
        """Write ORC file using PyArrow ORC writer.

        Note: PyArrow's ORC writer doesn't support a compression_level
        parameter. ORC files will use zstd compression with the codec's
        default level (which is 3, see
        https://github.com/facebook/zstd/blob/dev/programs/zstdcli.c)
        instead of the specified level.
        """
        try:
            import sys
            import pyarrow.orc as orc
with self.new_output_stream(path) as output_stream:
# Check Python version - if 3.6, don't use compression parameter
if sys.version_info[:2] == (3, 6):
orc.write_table(data, output_stream, **kwargs)
else:
orc.write_table(
data,
output_stream,
compression=compression,
**kwargs
)
except Exception as e:
self.delete_quietly(path)
raise RuntimeError(f"Failed to write ORC file {path}: {e}") from e
def write_avro(
self, path: str, data: pyarrow.Table,
avro_schema: Optional[Dict[str, Any]] = None,
compression: str = 'zstd', zstd_level: int = 1, **kwargs):
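        """Write a pyarrow.Table as Avro via fastavro.

        Rows are materialized as plain dicts, naive datetimes are assumed to
        be UTC, and the Paimon compression name is mapped to a fastavro codec
        (e.g. 'zstd' -> 'zstandard', honoring zstd_level).
        """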
import fastavro
        if avro_schema is None:
            # PyarrowFieldParser is already imported at module level.
            avro_schema = PyarrowFieldParser.to_avro_schema(data.schema)
records_dict = data.to_pydict()
        def record_generator():
            # data.num_rows is robust even for tables with no columns,
            # unlike indexing into records_dict.values().
            num_rows = data.num_rows
            for i in range(num_rows):
                record = {}
                for col in records_dict.keys():
                    value = records_dict[col][i]
                    # fastavro expects timezone-aware datetimes for timestamp
                    # logical types; assume UTC for naive values.
                    if isinstance(value, datetime) and value.tzinfo is None:
                        value = value.replace(tzinfo=timezone.utc)
                    record[col] = value
                yield record
records = record_generator()
codec_map = {
'null': 'null',
'deflate': 'deflate',
'snappy': 'snappy',
'bzip2': 'bzip2',
'xz': 'xz',
'zstandard': 'zstandard',
'zstd': 'zstandard', # zstd is commonly used in Paimon
}
compression_lower = compression.lower()
codec = codec_map.get(compression_lower)
if codec is None:
raise ValueError(
f"Unsupported compression '{compression}' for Avro format. "
f"Supported compressions: {', '.join(sorted(codec_map.keys()))}."
)
with self.new_output_stream(path) as output_stream:
if codec == 'zstandard':
kwargs['codec_compression_level'] = zstd_level
fastavro.writer(output_stream, avro_schema, records, codec=codec, **kwargs)
def write_lance(self, path: str, data: pyarrow.Table, **kwargs):
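        """Write a pyarrow.Table as a Lance file, streaming batch by batch."""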
try:
import lance
from pypaimon.read.reader.lance_utils import to_lance_specified
file_path_for_lance, storage_options = to_lance_specified(self, path)
writer = lance.file.LanceFileWriter(
file_path_for_lance, data.schema, storage_options=storage_options, **kwargs)
try:
# Write all batches
for batch in data.to_batches():
writer.write_batch(batch)
finally:
writer.close()
except Exception as e:
self.delete_quietly(path)
raise RuntimeError(f"Failed to write Lance file {path}: {e}") from e
def write_blob(self, path: str, data: pyarrow.Table, blob_as_descriptor: bool, **kwargs):
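        """Write a single-column pyarrow.Table in Paimon's blob format.

        Exactly one column without nulls is required. When blob_as_descriptor
        is True, each value is a serialized BlobDescriptor resolved through a
        URI reader; otherwise raw bytes (or UTF-8 encoded strings) are written
        as BlobData.
        """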
try:
if data.num_columns != 1:
raise RuntimeError(f"Blob format only supports a single column, got {data.num_columns} columns")
column = data.column(0)
if column.null_count > 0:
raise RuntimeError("Blob format does not support null values")
field = data.schema[0]
if pyarrow.types.is_large_binary(field.type):
fields = [DataField(0, field.name, AtomicType("BLOB"))]
else:
paimon_type = PyarrowFieldParser.to_paimon_type(field.type, field.nullable)
fields = [DataField(0, field.name, paimon_type)]
records_dict = data.to_pydict()
num_rows = data.num_rows
field_name = fields[0].name
with self.new_output_stream(path) as output_stream:
writer = BlobFormatWriter(output_stream)
for i in range(num_rows):
col_data = records_dict[field_name][i]
if hasattr(fields[0].type, 'type') and fields[0].type.type == "BLOB":
if blob_as_descriptor:
blob_descriptor = BlobDescriptor.deserialize(col_data)
uri_reader = self.uri_reader_factory.create(blob_descriptor.uri)
blob_data = Blob.from_descriptor(uri_reader, blob_descriptor)
elif isinstance(col_data, bytes):
blob_data = BlobData(col_data)
else:
if hasattr(col_data, 'as_py'):
col_data = col_data.as_py()
if isinstance(col_data, str):
col_data = col_data.encode('utf-8')
blob_data = BlobData(col_data)
row_values = [blob_data]
else:
row_values = [col_data]
row = GenericRow(row_values, fields, RowKind.INSERT)
writer.add_element(row)
writer.close()
except Exception as e:
self.delete_quietly(path)
raise RuntimeError(f"Failed to write blob file {path}: {e}") from e
def to_filesystem_path(self, path: str) -> str:
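        """Convert a Paimon location into a path the backing filesystem accepts.

        Examples derived from the branches below:

            s3://bucket/a/b  -> "bucket/a/b"  (S3/OSS, PyArrow >= 7)
            oss://bucket/a/b -> "a/b"         (OSS, PyArrow 6: bucket lives in the endpoint)
            file://C:/a/b    -> "C:/a/b"      (Windows drive letter)
            hdfs://nn/a/b    -> "/a/b"
        """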
from pyarrow.fs import S3FileSystem
import re
parsed = urlparse(path)
normalized_path = re.sub(r'/+', '/', parsed.path) if parsed.path else ''
if parsed.scheme and len(parsed.scheme) == 1 and not parsed.netloc:
return str(path)
if parsed.scheme == 'file' and parsed.netloc and parsed.netloc.endswith(':'):
drive_letter = parsed.netloc.rstrip(':')
path_part = normalized_path.lstrip('/')
return f"{drive_letter}:/{path_part}" if path_part else f"{drive_letter}:"
if isinstance(self.filesystem, S3FileSystem):
if parsed.scheme:
if parsed.netloc:
path_part = normalized_path.lstrip('/')
if self._is_oss and not self._pyarrow_gte_7:
# For PyArrow 6.x + OSS, endpoint_override already contains bucket,
result = path_part if path_part else '.'
return result
else:
result = f"{parsed.netloc}/{path_part}" if path_part else parsed.netloc
return result
else:
result = normalized_path.lstrip('/')
return result if result else '.'
else:
return str(path)
if parsed.scheme:
if not normalized_path:
return '.'
return normalized_path
return str(path)