################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import hashlib
import json
from abc import ABC, abstractmethod
from typing import List, Tuple

import pyarrow as pa

from pypaimon.common.core_options import CoreOptions
from pypaimon.schema.table_schema import TableSchema
from pypaimon.table.bucket_mode import BucketMode


class RowKeyExtractor(ABC):
    """Base class for extracting partition and bucket information from PyArrow data."""

    def __init__(self, table_schema: TableSchema):
        self.table_schema = table_schema
        self.partition_indices = self._get_field_indices(table_schema.partition_keys)

def extract_partition_bucket_batch(self, data: pa.RecordBatch) -> Tuple[List[Tuple], List[int]]:
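        """Compute the partition tuple and bucket id for each row of ``data``.

        Returns two parallel lists with one entry per row.
        """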
partitions = self._extract_partitions_batch(data)
buckets = self._extract_buckets_batch(data)
return partitions, buckets

    def _get_field_indices(self, field_names: List[str]) -> List[int]:
        if not field_names:
            return []
        field_map = {field.name: i for i, field in enumerate(self.table_schema.fields)}
        # Names missing from the schema are silently skipped.
        return [field_map[name] for name in field_names if name in field_map]

    def _extract_partitions_batch(self, data: pa.RecordBatch) -> List[Tuple]:
        if not self.partition_indices:
            return [() for _ in range(data.num_rows)]
        partition_columns = [data.column(i) for i in self.partition_indices]
        # Materialize one partition tuple per row from the partition columns.
        partitions = []
        for row_idx in range(data.num_rows):
            partition_values = tuple(col[row_idx].as_py() for col in partition_columns)
            partitions.append(partition_values)
        return partitions

    @abstractmethod
    def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
        """Extract bucket numbers for all rows. Must be implemented by subclasses."""


class FixedBucketRowKeyExtractor(RowKeyExtractor):
    """Fixed bucket mode extractor with a configurable number of buckets."""

    def __init__(self, table_schema: TableSchema):
        super().__init__(table_schema)
        self.num_buckets = int(table_schema.options.get(CoreOptions.BUCKET, -1))
        if self.num_buckets <= 0:
            raise ValueError(f"Fixed bucket mode requires bucket > 0, got {self.num_buckets}")
        bucket_key_option = table_schema.options.get(CoreOptions.BUCKET_KEY, '')
        if bucket_key_option.strip():
            self.bucket_keys = [k.strip() for k in bucket_key_option.split(',')]
        else:
            # Default to the primary keys that are not partition keys.
            self.bucket_keys = [pk for pk in table_schema.primary_keys
                                if pk not in table_schema.partition_keys]
        self.bucket_key_indices = self._get_field_indices(self.bucket_keys)

    def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
        columns = [data.column(i) for i in self.bucket_key_indices]
        hashes = []
        for row_idx in range(data.num_rows):
            row_values = tuple(col[row_idx].as_py() for col in columns)
            hashes.append(self.hash(row_values))
        # Fold each hash into the fixed bucket range [0, num_buckets).
        return [abs(hash_val) % self.num_buckets for hash_val in hashes]

    @staticmethod
    def hash(data) -> int:
        # Deterministic content hash: JSON-serialize the bucket-key tuple, then MD5.
        # default=str covers values JSON cannot encode natively (dates, decimals, ...).
        data_json = json.dumps(data, default=str)
        return int(hashlib.md5(data_json.encode()).hexdigest(), 16)
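
# A minimal usage sketch (hypothetical schema and column names, not part of this
# module). Assuming a TableSchema with options 'bucket' = '4' and
# 'bucket-key' = 'user_id', and 'dt' as the only partition key:
#
#     extractor = FixedBucketRowKeyExtractor(schema)
#     batch = pa.RecordBatch.from_pydict({'user_id': [1, 2, 3], 'dt': ['a', 'a', 'b']})
#     partitions, buckets = extractor.extract_partition_bucket_batch(batch)
#     # partitions -> [('a',), ('a',), ('b',)]
#     # buckets[i] -> abs(md5(json((user_id,)))) % 4 for row i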


class UnawareBucketRowKeyExtractor(RowKeyExtractor):
    """Extractor for unaware bucket mode (bucket = -1, no primary keys)."""

    def __init__(self, table_schema: TableSchema):
        super().__init__(table_schema)
        num_buckets = int(table_schema.options.get(CoreOptions.BUCKET, -1))
        if num_buckets != -1:
            raise ValueError(f"Unaware bucket mode requires bucket = -1, got {num_buckets}")

    def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
        # Append-only data is not bucketed by key; every row goes to bucket 0.
        return [0] * data.num_rows


class DynamicBucketRowKeyExtractor(RowKeyExtractor):
    """
    Row key extractor for dynamic bucket mode.

    Ensures the bucket configuration is set to -1 and prevents bucket extraction.
    """

    def __init__(self, table_schema: TableSchema):
        super().__init__(table_schema)
        num_buckets = int(table_schema.options.get(CoreOptions.BUCKET, -1))
        if num_buckets != -1:
            raise ValueError(
                f"Only 'bucket' = '-1' is allowed for 'DynamicBucketRowKeyExtractor', but found: {num_buckets}"
            )

    def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
        raise ValueError("Can't extract bucket from row in dynamic bucket mode")


class PostponeBucketRowKeyExtractor(RowKeyExtractor):
    """Extractor for postpone bucket mode (bucket = -2)."""

    def __init__(self, table_schema: TableSchema):
        super().__init__(table_schema)
        num_buckets = int(table_schema.options.get(CoreOptions.BUCKET, -2))
        if num_buckets != BucketMode.POSTPONE_BUCKET.value:
            raise ValueError(f"Postpone bucket mode requires bucket = -2, got {num_buckets}")

    def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
        # Tag every row with the sentinel postpone bucket id.
        return [BucketMode.POSTPONE_BUCKET.value] * data.num_rows
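
# A minimal selection sketch (hypothetical helper, not an API of this module):
# a writer could choose the extractor from the table's bucket configuration.
#
#     def create_extractor(schema: TableSchema) -> RowKeyExtractor:
#         bucket = int(schema.options.get(CoreOptions.BUCKET, -1))
#         if bucket > 0:
#             return FixedBucketRowKeyExtractor(schema)
#         if bucket == BucketMode.POSTPONE_BUCKET.value:
#             return PostponeBucketRowKeyExtractor(schema)
#         if schema.primary_keys:
#             return DynamicBucketRowKeyExtractor(schema)
#         return UnawareBucketRowKeyExtractor(schema)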