################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from typing import Any, List, Optional

import pyarrow as pa
import pyarrow.dataset as ds
from pyarrow import RecordBatch

from pypaimon.common.file_io import FileIO
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader


class FormatPyArrowReader(RecordBatchReader):
    """
    A format reader that reads record batches from a Parquet or ORC file using PyArrow,
    filtering rows with the pushed-down predicate and projecting the requested fields.

    A usage sketch is given at the end of this module.
    """

    def __init__(self, file_io: FileIO, file_format: str, file_path: str, read_fields: List[str],
                 push_down_predicate: Any, batch_size: int = 4096):
        file_path_for_pyarrow = file_io.to_filesystem_path(file_path)
        self.dataset = ds.dataset(file_path_for_pyarrow, format=file_format, filesystem=file_io.filesystem)
        self.read_fields = read_fields

        # Identify which requested fields exist in the file and which are missing
        file_schema_names = set(self.dataset.schema.names)
        self.existing_fields = [field for field in read_fields if field in file_schema_names]
        self.missing_fields = [field for field in read_fields if field not in file_schema_names]

        # Only pass existing fields to the PyArrow scanner to avoid errors
        self.reader = self.dataset.scanner(
            columns=self.existing_fields,
            filter=push_down_predicate,
            batch_size=batch_size
        ).to_reader()
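
    # Note: the PyArrow scanner applies the column projection and the pushed-down filter
    # while scanning, so only the requested columns and matching rows are materialized in
    # each batch. Fields that are requested but absent from the file (typically columns
    # added to the table schema after the file was written) are filled with nulls in
    # read_arrow_batch() below.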

    def read_arrow_batch(self) -> Optional[RecordBatch]:
        try:
            batch = self.reader.read_next_batch()
            if not self.missing_fields:
                return batch

            # Create all-null columns for the missing fields
            missing_columns = [pa.nulls(batch.num_rows, type=pa.null()) for _ in self.missing_fields]

            # Reconstruct the batch with all fields in the correct order
            all_columns = []
            for field_name in self.read_fields:
                if field_name in self.existing_fields:
                    # Take the column from the batch read from the file
                    column_idx = self.existing_fields.index(field_name)
                    all_columns.append(batch.column(column_idx))
                else:
                    # Take the all-null column created for the missing field
                    column_idx = self.missing_fields.index(field_name)
                    all_columns.append(missing_columns[column_idx])

            # Build a new RecordBatch with all columns in the requested order
            return pa.RecordBatch.from_arrays(all_columns, names=self.read_fields)
        except StopIteration:
            return None

    def close(self):
        if self.reader is not None:
            self.reader = None
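

# Usage sketch (illustrative only, kept as a comment so nothing runs at import time).
# The FileIO instance, file path, field names and predicate below are assumptions for
# the example, not values defined in this module; `ds` refers to the pyarrow.dataset
# import at the top of the file.
#
#     reader = FormatPyArrowReader(
#         file_io=file_io,                           # a pypaimon FileIO for the target filesystem
#         file_format="parquet",                     # or "orc"
#         file_path="path/to/data-file.parquet",     # hypothetical data file
#         read_fields=["id", "name", "new_col"],     # "new_col" may be absent from older files
#         push_down_predicate=ds.field("id") > 100,  # or None for no filtering
#     )
#     try:
#         while (batch := reader.read_arrow_batch()) is not None:
#             ...                                    # each batch is a pyarrow.RecordBatch
#     finally:
#         reader.close()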