################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from typing import List, Optional

import pyarrow as pa
from pyarrow import RecordBatch

from pypaimon.read.partition_info import PartitionInfo
from pypaimon.read.reader.format_blob_reader import FormatBlobReader
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
from pypaimon.schema.data_types import DataField, PyarrowFieldParser
from pypaimon.table.special_fields import SpecialFields


class DataFileBatchReader(RecordBatchReader):
    """
    Reads record batches from files of different formats, restoring partition
    columns, applying index mapping, and assigning row-tracking meta fields.
    """

    def __init__(self, format_reader: RecordBatchReader, index_mapping: List[int], partition_info: PartitionInfo,
                 system_primary_key: Optional[List[str]], fields: List[DataField],
                 max_sequence_number: int,
                 first_row_id: int,
                 row_tracking_enabled: bool,
                 system_fields: dict):
        self.format_reader = format_reader
        self.index_mapping = index_mapping
        self.partition_info = partition_info
        self.system_primary_key = system_primary_key
        self.schema_map = {field.name: field for field in PyarrowFieldParser.from_paimon_schema(fields)}
        self.row_tracking_enabled = row_tracking_enabled
        self.first_row_id = first_row_id
        self.max_sequence_number = max_sequence_number
        self.system_fields = system_fields

    def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch]:
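        # Blob readers support positional range reads; other format readers
        # always return the next sequential batch.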
        if isinstance(self.format_reader, FormatBlobReader):
            record_batch = self.format_reader.read_arrow_batch(start_idx, end_idx)
        else:
            record_batch = self.format_reader.read_arrow_batch()
        if record_batch is None:
            return None
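        # Fast path: no partition columns to restore and no column re-mapping
        # to apply, so the batch can be returned as-is.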
        if self.partition_info is None and self.index_mapping is None:
            if self.row_tracking_enabled and self.system_fields:
                record_batch = self._assign_row_tracking(record_batch)
            return record_batch

        inter_arrays = []
        inter_names = []
        num_rows = record_batch.num_rows
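        # Partition columns are not stored in the data file, so materialize them
        # as constant arrays and take the remaining columns by their real index.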
        if self.partition_info is not None:
            for i in range(self.partition_info.size()):
                if self.partition_info.is_partition_row(i):
                    partition_value, partition_field = self.partition_info.get_partition_value(i)
                    const_array = pa.repeat(partition_value, num_rows)
                    inter_arrays.append(const_array)
                    inter_names.append(partition_field.name)
                else:
                    real_index = self.partition_info.get_real_index(i)
                    if real_index < record_batch.num_columns:
                        inter_arrays.append(record_batch.column(real_index))
                        inter_names.append(record_batch.schema.field(real_index).name)
        else:
            inter_arrays = record_batch.columns
            inter_names = record_batch.schema.names
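        # Reorder columns to match the read schema; indices that fall outside the
        # source layout become all-null columns (e.g. fields the file predates).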
        if self.index_mapping is not None:
            mapped_arrays = []
            mapped_names = []
            for i, real_index in enumerate(self.index_mapping):
                if 0 <= real_index < len(inter_arrays):
                    mapped_arrays.append(inter_arrays[real_index])
                    mapped_names.append(inter_names[real_index])
                else:
                    null_array = pa.nulls(num_rows)
                    mapped_arrays.append(null_array)
                    mapped_names.append(f"null_col_{i}")
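            # Key columns of primary-key tables are read under a '_KEY_' prefix;
            # make sure the mapped key column names carry that prefix.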
            if self.system_primary_key:
                for i in range(len(self.system_primary_key)):
                    if not mapped_names[i].startswith("_KEY_"):
                        mapped_names[i] = f"_KEY_{mapped_names[i]}"
            inter_arrays = mapped_arrays
            inter_names = mapped_names
        # Rebuild the schema from the table fields so that 'not null' properties
        # are preserved; fall back to the array's own type for unknown columns.
        final_fields = []
        for i, name in enumerate(inter_names):
            array = inter_arrays[i]
            target_field = self.schema_map.get(name)
            if target_field is None:
                target_field = pa.field(name, array.type)
            final_fields.append(target_field)
        final_schema = pa.schema(final_fields)
        record_batch = pa.RecordBatch.from_arrays(inter_arrays, schema=final_schema)

        # Handle row tracking fields
        if self.row_tracking_enabled and self.system_fields:
            record_batch = self._assign_row_tracking(record_batch)
        return record_batch

    def _assign_row_tracking(self, record_batch: RecordBatch) -> RecordBatch:
        """Assign the row tracking meta fields (_ROW_ID and _SEQUENCE_NUMBER)."""
        arrays = list(record_batch.columns)
        # Handle the _ROW_ID field: row IDs are consecutive, starting at first_row_id.
        if SpecialFields.ROW_ID.name in self.system_fields:
            idx = self.system_fields[SpecialFields.ROW_ID.name]
            arrays[idx] = pa.array(
                range(self.first_row_id, self.first_row_id + record_batch.num_rows), type=pa.int64())
        # Handle the _SEQUENCE_NUMBER field: every row in the batch is filled
        # with max_sequence_number.
        if SpecialFields.SEQUENCE_NUMBER.name in self.system_fields:
            idx = self.system_fields[SpecialFields.SEQUENCE_NUMBER.name]
            arrays[idx] = pa.repeat(self.max_sequence_number, record_batch.num_rows)
        return pa.RecordBatch.from_arrays(arrays, names=record_batch.schema.names)

    def close(self) -> None:
        self.format_reader.close()