################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
from abc import ABC, abstractmethod
from functools import partial
from typing import List, Optional, Tuple

from pypaimon.pynative.common.data_field import DataField, DataType
from pypaimon.pynative.common.predicate import PredicateImpl
from pypaimon.pynative.read.interval_partition import IntervalPartition, SortedRun
from pypaimon.pynative.read.partition_info import PartitionInfo
from pypaimon.pynative.read.reader.concat_batch_reader import ConcatBatchReader
from pypaimon.pynative.read.reader.concat_record_reader import ConcatRecordReader
from pypaimon.pynative.read.reader.data_file_record_reader import DataFileBatchReader
from pypaimon.pynative.read.reader.drop_delete_reader import DropDeleteRecordReader
from pypaimon.pynative.read.reader.empty_record_reader import EmptyFileRecordReader
from pypaimon.pynative.read.reader.filter_record_reader import FilterRecordReader
from pypaimon.pynative.read.reader.format_avro_reader import FormatAvroReader
from pypaimon.pynative.read.reader.format_pyarrow_reader import FormatPyArrowReader
from pypaimon.pynative.read.reader.iface.record_batch_reader import RecordBatchReader
from pypaimon.pynative.read.reader.iface.record_reader import RecordReader
from pypaimon.pynative.read.reader.key_value_unwrap_reader import KeyValueUnwrapRecordReader
from pypaimon.pynative.read.reader.key_value_wrap_reader import KeyValueWrapReader
from pypaimon.pynative.read.reader.sort_merge_reader import SortMergeReaderWithMinHeap
from pypaimon.pynative.read.split_impl import SplitImpl
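
# Key fields in a merge read are given the "_KEY_" name prefix and an ID offset
# of KEY_FIELD_ID_START so they never collide with value fields; NULL_FIELD_INDEX
# marks a field that is absent from the data file and is filled with nulls.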
KEY_PREFIX = "_KEY_"
KEY_FIELD_ID_START = 1000000
NULL_FIELD_INDEX = -1


class SplitRead(ABC):
    """Abstract base class for split reading operations."""

    def __init__(self, table, predicate: Optional[PredicateImpl], read_type: List[DataField], split: SplitImpl):
        from pypaimon.pynative.table.file_store_table import FileStoreTable

        self.table: FileStoreTable = table
        self.predicate = predicate
        self.split = split
        self.value_arity = len(read_type)
        self.trimmed_primary_key = [field.name for field in self.table.table_schema.get_trimmed_primary_key_fields()]
        self.read_fields = read_type
        if isinstance(self, MergeFileSplitRead):
            self.read_fields = self._create_key_value_fields(read_type)

    @abstractmethod
    def create_reader(self) -> RecordReader:
        """Create a record reader for the given split."""
    def file_reader_supplier(self, file_path: str, for_merge_read: bool):
        """Create a batch reader for a single data file, choosing the format reader by file extension."""
        _, extension = os.path.splitext(file_path)
        file_format = extension[1:]
        format_reader: RecordBatchReader
        if file_format == "avro":
            format_reader = FormatAvroReader(self.table.file_io, file_path, self.table.primary_keys,
                                             self._get_final_read_data_fields(), self.predicate)
        elif file_format in ("parquet", "orc"):
            format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path,
                                                self.table.primary_keys,
                                                self._get_final_read_data_fields(), self.predicate)
        else:
            raise ValueError(f"Unexpected file format: {file_format}")

        index_mapping = self.create_index_mapping()
        partition_info = self.create_partition_info()
        if for_merge_read:
            return DataFileBatchReader(format_reader, index_mapping, partition_info, self.trimmed_primary_key)
        else:
            return DataFileBatchReader(format_reader, index_mapping, partition_info, None)
    @abstractmethod
    def _get_all_data_fields(self):
        """Get all data fields, including any key and system fields, of the underlying files."""

    def _get_read_data_fields(self):
        read_data_fields = []
        read_field_ids = {field.id for field in self.read_fields}
        for data_field in self._get_all_data_fields():
            if data_field.id in read_field_ids:
                read_data_fields.append(data_field)
        return read_data_fields
    def _create_key_value_fields(self, value_fields: List[DataField]):
        """Build the full key-value schema: key fields, then system fields, then value fields."""
        all_fields: List[DataField] = self.table.fields
        all_data_fields = []
        for field in all_fields:
            if field.name in self.trimmed_primary_key:
                key_field_name = f"{KEY_PREFIX}{field.name}"
                key_field_id = field.id + KEY_FIELD_ID_START
                key_field = DataField(key_field_id, key_field_name, field.type)
                all_data_fields.append(key_field)
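        # System field IDs match Paimon's special fields: 2147483646 is
        # Integer.MAX_VALUE - 1 (_SEQUENCE_NUMBER) and 2147483645 is
        # Integer.MAX_VALUE - 2 (_VALUE_KIND).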
        sequence_field = DataField(2147483646, "_SEQUENCE_NUMBER", DataType("BIGINT", nullable=False))
        all_data_fields.append(sequence_field)
        value_kind_field = DataField(2147483645, "_VALUE_KIND", DataType("TINYINT", nullable=False))
        all_data_fields.append(value_kind_field)
        for field in value_fields:
            all_data_fields.append(field)
        return all_data_fields
    def create_index_mapping(self):
        """Combine the base index mapping with the trimmed-key mapping into one read-time mapping."""
        base_index_mapping = self._create_base_index_mapping(self.read_fields, self._get_read_data_fields())
        trimmed_key_mapping, _ = self._get_trimmed_fields(self._get_read_data_fields(), self._get_all_data_fields())
        if base_index_mapping is None:
            mapping = trimmed_key_mapping
        elif trimmed_key_mapping is None:
            mapping = base_index_mapping
        else:
            combined = [0] * len(base_index_mapping)
            for i in range(len(base_index_mapping)):
                if base_index_mapping[i] < 0:
                    combined[i] = base_index_mapping[i]
                else:
                    combined[i] = trimmed_key_mapping[base_index_mapping[i]]
            mapping = combined
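        # An identity mapping needs no remapping at read time, so return None instead.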
        if mapping is not None:
            for i in range(len(mapping)):
                if mapping[i] != i:
                    return mapping
        return None
    def _create_base_index_mapping(self, table_fields: List[DataField], data_fields: List[DataField]):
        index_mapping = [0] * len(table_fields)
        field_id_to_index = {field.id: i for i, field in enumerate(data_fields)}
        for i, table_field in enumerate(table_fields):
            field_id = table_field.id
            data_field_index = field_id_to_index.get(field_id)
            if data_field_index is not None:
                index_mapping[i] = data_field_index
            else:
                # The field is absent from the data files; fill it with nulls at read time.
                index_mapping[i] = NULL_FIELD_INDEX
        for i in range(len(index_mapping)):
            if index_mapping[i] != i:
                return index_mapping
        return None
    def _get_final_read_data_fields(self) -> List[str]:
        """Names of the fields to actually read from the file: deduplicated, minus partition fields."""
        _, trimmed_fields = self._get_trimmed_fields(
            self._get_read_data_fields(), self._get_all_data_fields()
        )
        return self._remove_partition_fields(trimmed_fields)

    def _remove_partition_fields(self, fields: List[DataField]) -> List[str]:
        partition_keys = self.table.partition_keys
        if not partition_keys:
            return [field.name for field in fields]

        fields_without_partition = []
        for field in fields:
            if field.name not in partition_keys:
                fields_without_partition.append(field)
        return [field.name for field in fields_without_partition]
    def _get_trimmed_fields(self, read_data_fields: List[DataField],
                            all_data_fields: List[DataField]) -> Tuple[List[int], List[DataField]]:
        trimmed_mapping = [0] * len(read_data_fields)
        trimmed_fields = []
        field_id_to_field = {field.id: field for field in all_data_fields}
        position_map = {}
        for i, field in enumerate(read_data_fields):
            is_key_field = field.name.startswith(KEY_PREFIX)
            if is_key_field:
                original_id = field.id - KEY_FIELD_ID_START
            else:
                original_id = field.id
            original_field = field_id_to_field.get(original_id)
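            # A _KEY_ field and its value twin share one physical column: read it
            # only once and map both logical fields to the same position.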
            if original_id in position_map:
                trimmed_mapping[i] = position_map[original_id]
            else:
                position = len(trimmed_fields)
                position_map[original_id] = position
                trimmed_mapping[i] = position
                if is_key_field:
                    trimmed_fields.append(original_field)
                else:
                    trimmed_fields.append(field)
        return trimmed_mapping, trimmed_fields
    def create_partition_info(self):
        if not self.table.partition_keys:
            return None
        partition_mapping = self._construct_partition_mapping()
        if not partition_mapping:
            return None
        return PartitionInfo(partition_mapping, self.split.partition)
    def _construct_partition_mapping(self) -> List[int]:
        _, trimmed_fields = self._get_trimmed_fields(
            self._get_read_data_fields(), self._get_all_data_fields()
        )
        partition_names = self.table.partition_keys
        mapping = [0] * (len(trimmed_fields) + 1)
        p_count = 0
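        # 1-based encoding: -(k + 1) takes the k-th field of the split's partition
        # value, while (k + 1) takes the k-th column actually read from the file.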
        for i, field in enumerate(trimmed_fields):
            if field.name in partition_names:
                partition_index = partition_names.index(field.name)
                mapping[i] = -(partition_index + 1)
                p_count += 1
            else:
                mapping[i] = (i - p_count) + 1
        return mapping


class RawFileSplitRead(SplitRead):
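    """SplitRead over raw data files that can be consumed without merging."""
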
    def create_reader(self) -> RecordReader:
        data_readers = []
        for file_path in self.split.file_paths:
            supplier = partial(self.file_reader_supplier, file_path=file_path, for_merge_read=False)
            data_readers.append(supplier)
        if not data_readers:
            return EmptyFileRecordReader()
        # No filter reader needed: all predicates have been pushed down to the format readers.
        return ConcatBatchReader(data_readers)
    def _get_all_data_fields(self):
        return self.table.fields


class MergeFileSplitRead(SplitRead):
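    """SplitRead for primary key tables: merge-sorts key-values across sorted runs, drops deletes, and unwraps values."""
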
    def kv_reader_supplier(self, file_path):
        reader_supplier = partial(self.file_reader_supplier, file_path=file_path, for_merge_read=True)
        return KeyValueWrapReader(reader_supplier(), len(self.trimmed_primary_key), self.value_arity)
    def section_reader_supplier(self, section: List[SortedRun]):
        readers = []
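        # Files within a sorted run do not overlap, so they are simply concatenated;
        # the runs themselves overlap in key range and must be merge-sorted.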
        for sorted_run in section:
            data_readers = []
            for file in sorted_run.files:
                supplier = partial(self.kv_reader_supplier, file.file_path)
                data_readers.append(supplier)
            readers.append(ConcatRecordReader(data_readers))
        return SortMergeReaderWithMinHeap(readers, self.table.table_schema.to_schema())
    def create_reader(self) -> RecordReader:
        section_readers = []
        sections = IntervalPartition(self.split.files).partition()
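        # Sections do not overlap in key range, so the merged output of each
        # section can simply be concatenated with the next.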
        for section in sections:
            supplier = partial(self.section_reader_supplier, section)
            section_readers.append(supplier)
        concat_reader = ConcatRecordReader(section_readers)
        kv_unwrap_reader = KeyValueUnwrapRecordReader(DropDeleteRecordReader(concat_reader))
        if self.predicate:
            return FilterRecordReader(kv_unwrap_reader, self.predicate)
        else:
            return kv_unwrap_reader
    def _get_all_data_fields(self):
        return self._create_key_value_fields(self.table.fields)