| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| |
| import os |
| from abc import ABC, abstractmethod |
| from functools import partial |
| from typing import List, Optional, Tuple |
| |
| from pypaimon.common.core_options import CoreOptions |
| from pypaimon.common.predicate import Predicate |
| from pypaimon.manifest.schema.data_file_meta import DataFileMeta |
| from pypaimon.read.interval_partition import IntervalPartition, SortedRun |
| from pypaimon.read.partition_info import PartitionInfo |
| from pypaimon.read.push_down_utils import trim_predicate_by_fields |
| from pypaimon.read.reader.concat_batch_reader import ConcatBatchReader, ShardBatchReader, MergeAllBatchReader |
| from pypaimon.read.reader.concat_record_reader import ConcatRecordReader |
| from pypaimon.read.reader.data_file_batch_reader import DataFileBatchReader |
| from pypaimon.read.reader.data_evolution_merge_reader import DataEvolutionMergeReader |
| from pypaimon.read.reader.field_bunch import FieldBunch, DataBunch, BlobBunch |
| from pypaimon.read.reader.drop_delete_reader import DropDeleteRecordReader |
| from pypaimon.read.reader.empty_record_reader import EmptyFileRecordReader |
| from pypaimon.read.reader.filter_record_reader import FilterRecordReader |
| from pypaimon.read.reader.format_avro_reader import FormatAvroReader |
| from pypaimon.read.reader.format_blob_reader import FormatBlobReader |
| from pypaimon.read.reader.format_pyarrow_reader import FormatPyArrowReader |
| from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader |
| from pypaimon.read.reader.iface.record_reader import RecordReader |
from pypaimon.read.reader.key_value_unwrap_reader import KeyValueUnwrapRecordReader
| from pypaimon.read.reader.key_value_wrap_reader import KeyValueWrapReader |
| from pypaimon.read.reader.sort_merge_reader import SortMergeReaderWithMinHeap |
| from pypaimon.read.split import Split |
| from pypaimon.schema.data_types import AtomicType, DataField |
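
# Key fields projected for merge reads are prefixed with KEY_PREFIX and have their ids
# shifted by KEY_FIELD_ID_START so they never collide with value field ids.
# NULL_FIELD_INDEX marks a requested field that has no backing column in the data file.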
| |
| KEY_PREFIX = "_KEY_" |
| KEY_FIELD_ID_START = 1000000 |
| NULL_FIELD_INDEX = -1 |
| |
| |
| class SplitRead(ABC): |
| """Abstract base class for split reading operations.""" |
| |
| def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split): |
| from pypaimon.table.file_store_table import FileStoreTable |
| |
| self.table: FileStoreTable = table |
| self.predicate = predicate |
| self.push_down_predicate = self._push_down_predicate() |
| self.split = split |
| self.value_arity = len(read_type) |
| |
| self.trimmed_primary_key = self.table.trimmed_primary_keys |
| self.read_fields = read_type |
| if isinstance(self, MergeFileSplitRead): |
| self.read_fields = self._create_key_value_fields(read_type) |
| self.schema_id_2_fields = {} |
| |
| def _push_down_predicate(self) -> Optional[Predicate]: |
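        """Derive the predicate that can be pushed down to the file readers.

        For primary key tables only predicates on primary key columns are pushed down;
        the full predicate is re-applied on the merged rows in create_reader().
        """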
| if self.predicate is None: |
| return None |
| elif self.table.is_primary_key_table: |
| pk_predicate = trim_predicate_by_fields(self.predicate, self.table.primary_keys) |
| if not pk_predicate: |
| return None |
| return pk_predicate |
| else: |
| return self.predicate |
| |
| @abstractmethod |
| def create_reader(self) -> RecordReader: |
| """Create a record reader for the given split.""" |
| |
| def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, read_fields: List[str]): |
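        """Build a batch reader for a single data file.

        The concrete format reader (Avro, blob, Parquet/ORC) is chosen from the file
        extension and wrapped in a DataFileBatchReader that applies index mapping and
        partition re-insertion; for merge reads the trimmed primary key is passed so
        key columns can be extracted.
        """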
| (read_file_fields, read_arrow_predicate) = self._get_fields_and_predicate(file.schema_id, read_fields) |
| |
| # Use external_path if available, otherwise use file_path |
| file_path = file.external_path if file.external_path else file.file_path |
| _, extension = os.path.splitext(file_path) |
| file_format = extension[1:] |
| |
| format_reader: RecordBatchReader |
| if file_format == CoreOptions.FILE_FORMAT_AVRO: |
| format_reader = FormatAvroReader(self.table.file_io, file_path, read_file_fields, |
| self.read_fields, read_arrow_predicate) |
| elif file_format == CoreOptions.FILE_FORMAT_BLOB: |
| blob_as_descriptor = CoreOptions.blob_as_descriptor(self.table.options) |
| format_reader = FormatBlobReader(self.table.file_io, file_path, read_file_fields, |
| self.read_fields, read_arrow_predicate, blob_as_descriptor) |
| elif file_format == CoreOptions.FILE_FORMAT_PARQUET or file_format == CoreOptions.FILE_FORMAT_ORC: |
| format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path, |
| read_file_fields, read_arrow_predicate) |
| else: |
| raise ValueError(f"Unexpected file format: {file_format}") |
| |
| index_mapping = self.create_index_mapping() |
| partition_info = self.create_partition_info() |
| if for_merge_read: |
| return DataFileBatchReader(format_reader, index_mapping, partition_info, self.trimmed_primary_key, |
| self.table.table_schema.fields) |
| else: |
| return DataFileBatchReader(format_reader, index_mapping, partition_info, None, |
| self.table.table_schema.fields) |
| |
| def _get_fields_and_predicate(self, schema_id: int, read_fields): |
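        """Resolve the file columns and pushed-down arrow predicate for a schema.

        Results are cached per (schema_id, read_fields): only fields present in that
        schema version are read, and the predicate is trimmed to those fields.
        """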
| key = (schema_id, tuple(read_fields)) |
| if key not in self.schema_id_2_fields: |
| schema = self.table.schema_manager.get_schema(schema_id) |
| schema_field_names = set(field.name for field in schema.fields) |
| if self.table.is_primary_key_table: |
| schema_field_names.add('_SEQUENCE_NUMBER') |
| schema_field_names.add('_VALUE_KIND') |
| read_file_fields = [read_field for read_field in read_fields if read_field in schema_field_names] |
| read_predicate = trim_predicate_by_fields(self.push_down_predicate, read_file_fields) |
| read_arrow_predicate = read_predicate.to_arrow() if read_predicate else None |
| self.schema_id_2_fields[key] = (read_file_fields, read_arrow_predicate) |
| return self.schema_id_2_fields[key] |
| |
| @abstractmethod |
| def _get_all_data_fields(self): |
| """Get all data fields""" |
| |
| def _get_read_data_fields(self): |
| read_data_fields = [] |
| read_field_ids = {field.id for field in self.read_fields} |
| for data_field in self._get_all_data_fields(): |
| if data_field.id in read_field_ids: |
| read_data_fields.append(data_field) |
| return read_data_fields |
| |
    def _create_key_value_fields(self, value_fields: List[DataField]):
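        """Build the key-value row layout: _KEY_-prefixed copies of the trimmed primary
        key fields, the _SEQUENCE_NUMBER and _VALUE_KIND system fields, then the value
        fields.
        """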
| all_fields: List[DataField] = self.table.fields |
| all_data_fields = [] |
| |
| for field in all_fields: |
| if field.name in self.trimmed_primary_key: |
| key_field_name = f"{KEY_PREFIX}{field.name}" |
| key_field_id = field.id + KEY_FIELD_ID_START |
| key_field = DataField(key_field_id, key_field_name, field.type) |
| all_data_fields.append(key_field) |
| |
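        # System fields use reserved ids just below Integer.MAX_VALUE so they cannot
        # collide with user-defined field ids.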
| sequence_field = DataField(2147483646, "_SEQUENCE_NUMBER", AtomicType("BIGINT", nullable=False)) |
| all_data_fields.append(sequence_field) |
| value_kind_field = DataField(2147483645, "_VALUE_KIND", AtomicType("TINYINT", nullable=False)) |
| all_data_fields.append(value_kind_field) |
| |
        for field in value_fields:
| all_data_fields.append(field) |
| |
| return all_data_fields |
| |
| def create_index_mapping(self): |
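        """Combine the base table-to-data-field mapping with the key-trimming mapping.

        Returns None when the combined mapping is the identity, i.e. no remapping is
        needed.
        """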
| base_index_mapping = self._create_base_index_mapping(self.read_fields, self._get_read_data_fields()) |
| trimmed_key_mapping, _ = self._get_trimmed_fields(self._get_read_data_fields(), self._get_all_data_fields()) |
| if base_index_mapping is None: |
| mapping = trimmed_key_mapping |
| elif trimmed_key_mapping is None: |
| mapping = base_index_mapping |
| else: |
| combined = [0] * len(base_index_mapping) |
| for i in range(len(base_index_mapping)): |
| if base_index_mapping[i] < 0: |
| combined[i] = base_index_mapping[i] |
| else: |
| combined[i] = trimmed_key_mapping[base_index_mapping[i]] |
| mapping = combined |
| |
| if mapping is not None: |
| for i in range(len(mapping)): |
| if mapping[i] != i: |
| return mapping |
| |
| return None |
| |
| def _create_base_index_mapping(self, table_fields: List[DataField], data_fields: List[DataField]): |
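        """Map each requested table field to its position among the data fields,
        using NULL_FIELD_INDEX for fields the data files do not contain; returns
        None if the mapping is the identity.
        """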
| index_mapping = [0] * len(table_fields) |
| field_id_to_index = {field.id: i for i, field in enumerate(data_fields)} |
| |
| for i, table_field in enumerate(table_fields): |
| field_id = table_field.id |
| data_field_index = field_id_to_index.get(field_id) |
| if data_field_index is not None: |
| index_mapping[i] = data_field_index |
| else: |
| index_mapping[i] = NULL_FIELD_INDEX |
| |
| for i in range(len(index_mapping)): |
| if index_mapping[i] != i: |
| return index_mapping |
| |
| return None |
| |
| def _get_final_read_data_fields(self) -> List[str]: |
| _, trimmed_fields = self._get_trimmed_fields( |
| self._get_read_data_fields(), self._get_all_data_fields() |
| ) |
| return self._remove_partition_fields(trimmed_fields) |
| |
| def _remove_partition_fields(self, fields: List[DataField]) -> List[str]: |
| partition_keys = self.table.partition_keys |
| if not partition_keys: |
| return [field.name for field in fields] |
| |
| fields_without_partition = [] |
| for field in fields: |
| if field.name not in partition_keys: |
| fields_without_partition.append(field) |
| |
| return [field.name for field in fields_without_partition] |
| |
| def _get_trimmed_fields(self, read_data_fields: List[DataField], |
| all_data_fields: List[DataField]) -> Tuple[List[int], List[DataField]]: |
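        """De-duplicate _KEY_ fields against their value counterparts.

        A _KEY_ field and its value field share the same stored column, so both map to
        a single position in the returned trimmed field list.
        """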
| trimmed_mapping = [0] * len(read_data_fields) |
| trimmed_fields = [] |
| |
| field_id_to_field = {field.id: field for field in all_data_fields} |
| position_map = {} |
| for i, field in enumerate(read_data_fields): |
| is_key_field = field.name.startswith(KEY_PREFIX) |
| if is_key_field: |
| original_id = field.id - KEY_FIELD_ID_START |
| else: |
| original_id = field.id |
| original_field = field_id_to_field.get(original_id) |
| |
| if original_id in position_map: |
| trimmed_mapping[i] = position_map[original_id] |
| else: |
| position = len(trimmed_fields) |
| position_map[original_id] = position |
| trimmed_mapping[i] = position |
| if is_key_field: |
| trimmed_fields.append(original_field) |
| else: |
| trimmed_fields.append(field) |
| |
| return trimmed_mapping, trimmed_fields |
| |
| def create_partition_info(self): |
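        """Return a PartitionInfo for re-inserting partition values into rows, or None
        if the table is not partitioned or no remapping is required.
        """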
| if not self.table.partition_keys: |
| return None |
| partition_mapping = self._construct_partition_mapping() |
| if not partition_mapping: |
| return None |
| return PartitionInfo(partition_mapping, self.split.partition) |
| |
| def _construct_partition_mapping(self) -> List[int]: |
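        """Build a 1-based mapping over the trimmed fields: a negative entry -(i + 1)
        refers to the i-th value of the split's partition, a positive entry j refers to
        column j - 1 of the data file (partition columns are not stored in files).
        """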
| _, trimmed_fields = self._get_trimmed_fields( |
| self._get_read_data_fields(), self._get_all_data_fields() |
| ) |
| partition_names = self.table.partition_keys |
| |
| mapping = [0] * (len(trimmed_fields) + 1) |
| p_count = 0 |
| |
| for i, field in enumerate(trimmed_fields): |
| if field.name in partition_names: |
| partition_index = partition_names.index(field.name) |
| mapping[i] = -(partition_index + 1) |
| p_count += 1 |
| else: |
| mapping[i] = (i - p_count) + 1 |
| |
| return mapping |
| |
| |
| class RawFileSplitRead(SplitRead): |
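    """Reads a split by concatenating its data files directly, without key-value merging."""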
| |
| def create_reader(self) -> RecordReader: |
| data_readers = [] |
| for file in self.split.files: |
| supplier = partial( |
| self.file_reader_supplier, |
| file=file, |
| for_merge_read=False, |
| read_fields=self._get_final_read_data_fields(), |
| ) |
| data_readers.append(supplier) |
| |
| if not data_readers: |
| return EmptyFileRecordReader() |
| if self.split.split_start_row is not None: |
| concat_reader = ShardBatchReader(data_readers, self.split.split_start_row, self.split.split_end_row) |
| else: |
| concat_reader = ConcatBatchReader(data_readers) |
        # For append-only tables no extra filter is needed: all predicates have been pushed down.
        # For primary key tables only key predicates were pushed down, so re-apply the full predicate.
| if self.table.is_primary_key_table and self.predicate: |
| return FilterRecordReader(concat_reader, self.predicate) |
| else: |
| return concat_reader |
| |
| def _get_all_data_fields(self): |
| return self.table.fields |
| |
| |
| class MergeFileSplitRead(SplitRead): |
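    """Reads a split of a primary key table by merge-sorting its sorted runs and
    dropping delete records.
    """
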
| def kv_reader_supplier(self, file): |
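        """Wrap a data file reader so each row is exposed as a KeyValue
        (key, sequence number, value kind, value).
        """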
| reader_supplier = partial( |
| self.file_reader_supplier, |
| file=file, |
| for_merge_read=True, |
| read_fields=self._get_final_read_data_fields() |
| ) |
| return KeyValueWrapReader(reader_supplier(), len(self.trimmed_primary_key), self.value_arity) |
| |
| def section_reader_supplier(self, section: List[SortedRun]): |
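        """Merge-sort the sorted runs of one section; files within a run do not overlap
        by key range, so they are simply concatenated.
        """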
| readers = [] |
        for sorted_run in section:
            data_readers = []
            for file in sorted_run.files:
| supplier = partial(self.kv_reader_supplier, file) |
| data_readers.append(supplier) |
| readers.append(ConcatRecordReader(data_readers)) |
| return SortMergeReaderWithMinHeap(readers, self.table.table_schema) |
| |
| def create_reader(self) -> RecordReader: |
| section_readers = [] |
| sections = IntervalPartition(self.split.files).partition() |
| for section in sections: |
| supplier = partial(self.section_reader_supplier, section) |
| section_readers.append(supplier) |
| concat_reader = ConcatRecordReader(section_readers) |
| kv_unwrap_reader = KeyValueUnwrapRecordReader(DropDeleteRecordReader(concat_reader)) |
| if self.predicate: |
| return FilterRecordReader(kv_unwrap_reader, self.predicate) |
| else: |
| return kv_unwrap_reader |
| |
| def _get_all_data_fields(self): |
| return self._create_key_value_fields(self.table.fields) |
| |
| |
| class DataEvolutionSplitRead(SplitRead): |
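    """Reads a split whose files may each contain only a subset of the columns
    (data evolution): files are grouped by first_row_id and their column bunches
    are merged row-wise into full rows.
    """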
| |
| def create_reader(self) -> RecordReader: |
| files = self.split.files |
| suppliers = [] |
| |
| # Split files by row ID using the same logic as Java DataEvolutionSplitGenerator.split |
| split_by_row_id = self._split_by_row_id(files) |
| |
| for need_merge_files in split_by_row_id: |
| if len(need_merge_files) == 1 or not self.read_fields: |
| # No need to merge fields, just create a single file reader |
| suppliers.append( |
| lambda f=need_merge_files[0]: self._create_file_reader(f, self._get_final_read_data_fields()) |
| ) |
| else: |
| suppliers.append( |
| lambda files=need_merge_files: self._create_union_reader(files) |
| ) |
| if self.split.split_start_row is not None: |
| return ShardBatchReader(suppliers, self.split.split_start_row, self.split.split_end_row) |
| else: |
| return ConcatBatchReader(suppliers) |
| |
| def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]: |
| """Split files by firstRowId for data evolution.""" |
| |
        # Sort files by firstRowId, with blob files after regular files, newest (highest maxSequenceNumber) first
| def sort_key(file: DataFileMeta) -> tuple: |
| first_row_id = file.first_row_id if file.first_row_id is not None else float('-inf') |
| is_blob = 1 if self._is_blob_file(file.file_name) else 0 |
| max_seq = file.max_sequence_number |
| return (first_row_id, is_blob, -max_seq) |
| |
| sorted_files = sorted(files, key=sort_key) |
| |
| # Split files by firstRowId |
| split_by_row_id = [] |
| last_row_id = -1 |
| check_row_id_start = 0 |
| current_split = [] |
| |
| for file in sorted_files: |
| first_row_id = file.first_row_id |
| if first_row_id is None: |
| split_by_row_id.append([file]) |
| continue |
| |
| if not self._is_blob_file(file.file_name) and first_row_id != last_row_id: |
| if current_split: |
| split_by_row_id.append(current_split) |
| if first_row_id < check_row_id_start: |
| raise ValueError( |
| f"There are overlapping files in the split: {files}, " |
| f"the wrong file is: {file}" |
| ) |
| current_split = [] |
| last_row_id = first_row_id |
| check_row_id_start = first_row_id + file.row_count |
| current_split.append(file) |
| |
| if current_split: |
| split_by_row_id.append(current_split) |
| |
| return split_by_row_id |
| |
| def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordReader: |
| """Create a DataEvolutionFileReader for merging multiple files.""" |
| # Split field bunches |
| fields_files = self._split_field_bunches(need_merge_files) |
| |
| # Validate row counts and first row IDs |
| row_count = fields_files[0].row_count() |
| first_row_id = fields_files[0].files()[0].first_row_id |
| |
| for bunch in fields_files: |
| if bunch.row_count() != row_count: |
| raise ValueError("All files in a field merge split should have the same row count.") |
| if bunch.files()[0].first_row_id != first_row_id: |
| raise ValueError( |
| "All files in a field merge split should have the same first row id and could not be null." |
| ) |
| |
| # Create the union reader |
| all_read_fields = self.read_fields |
| file_record_readers = [None] * len(fields_files) |
| read_field_index = [field.id for field in all_read_fields] |
| |
| # Initialize offsets |
| row_offsets = [-1] * len(all_read_fields) |
| field_offsets = [-1] * len(all_read_fields) |
| |
| for i, bunch in enumerate(fields_files): |
| first_file = bunch.files()[0] |
| |
| # Get field IDs for this bunch |
| if self._is_blob_file(first_file.file_name): |
| # For blob files, we need to get the field ID from the write columns |
| field_ids = [self._get_field_id_from_write_cols(first_file)] |
| elif first_file.write_cols: |
| field_ids = self._get_field_ids_from_write_cols(first_file.write_cols) |
| else: |
| # For regular files, get all field IDs from the schema |
| field_ids = [field.id for field in self.table.fields] |
| |
| read_fields = [] |
| for j, read_field_id in enumerate(read_field_index): |
| for field_id in field_ids: |
| if read_field_id == field_id: |
| if row_offsets[j] == -1: |
| row_offsets[j] = i |
| field_offsets[j] = len(read_fields) |
| read_fields.append(all_read_fields[j]) |
| break |
| |
| if not read_fields: |
| file_record_readers[i] = None |
| else: |
| read_field_names = self._remove_partition_fields(read_fields) |
| table_fields = self.read_fields |
                # Temporarily narrow read_fields so the file reader projects only this bunch's columns
                self.read_fields = read_fields
| # Create reader for this bunch |
| if len(bunch.files()) == 1: |
| file_record_readers[i] = self._create_file_reader( |
| bunch.files()[0], read_field_names |
| ) |
| else: |
| # Create concatenated reader for multiple files |
| suppliers = [ |
| lambda f=file: self._create_file_reader( |
| f, read_field_names |
| ) for file in bunch.files() |
| ] |
| file_record_readers[i] = MergeAllBatchReader(suppliers) |
| self.read_fields = table_fields |
| |
| # Validate that all required fields are found |
| for i, field in enumerate(all_read_fields): |
| if row_offsets[i] == -1: |
| if not field.type.nullable: |
| raise ValueError(f"Field {field} is not null but can't find any file contains it.") |
| |
| return DataEvolutionMergeReader(row_offsets, field_offsets, file_record_readers) |
| |
    def _create_file_reader(self, file: DataFileMeta, read_fields: List[str]) -> RecordReader:
| """Create a file reader for a single file.""" |
| return self.file_reader_supplier(file=file, for_merge_read=False, read_fields=read_fields) |
| |
| def _split_field_bunches(self, need_merge_files: List[DataFileMeta]) -> List[FieldBunch]: |
| """Split files into field bunches.""" |
| |
| fields_files = [] |
| blob_bunch_map = {} |
| row_count = -1 |
| |
| for file in need_merge_files: |
| if self._is_blob_file(file.file_name): |
| field_id = self._get_field_id_from_write_cols(file) |
| if field_id not in blob_bunch_map: |
| blob_bunch_map[field_id] = BlobBunch(row_count) |
| blob_bunch_map[field_id].add(file) |
| else: |
                # Regular data file: each file forms its own bunch
| fields_files.append(DataBunch(file)) |
| row_count = file.row_count |
| |
| fields_files.extend(blob_bunch_map.values()) |
| return fields_files |
| |
| def _get_field_id_from_write_cols(self, file: DataFileMeta) -> int: |
| """Get field ID from write columns for blob files.""" |
        if not file.write_cols:
| raise ValueError("Blob file must have write columns") |
| |
| # Find the field by name in the table schema |
| field_name = file.write_cols[0] |
| for field in self.table.fields: |
| if field.name == field_name: |
| return field.id |
| raise ValueError(f"Field {field_name} not found in table schema") |
| |
| def _get_field_ids_from_write_cols(self, write_cols: List[str]) -> List[int]: |
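        """Map write column names to their field ids in the table schema, preserving
        the order of write_cols and skipping unknown names.
        """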
| field_ids = [] |
| for field_name in write_cols: |
| for field in self.table.fields: |
| if field.name == field_name: |
| field_ids.append(field.id) |
| return field_ids |
| |
| @staticmethod |
| def _is_blob_file(file_name: str) -> bool: |
| """Check if a file is a blob file based on its extension.""" |
| return file_name.endswith('.blob') |
| |
| def _get_all_data_fields(self): |
| return self.table.fields |