blob: aedd593b702b1b75435fbfeb3d7f8114b3435e3c [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import heapq
from typing import Any, Callable, List, Optional
from pypaimon.read.reader.iface.record_iterator import RecordIterator
from pypaimon.read.reader.iface.record_reader import RecordReader
from pypaimon.schema.data_types import DataField, Keyword
from pypaimon.schema.table_schema import TableSchema
from pypaimon.table.row.internal_row import InternalRow
from pypaimon.table.row.key_value import KeyValue
class SortMergeReaderWithMinHeap(RecordReader):
    """Merges several sorted KeyValue readers into one sorted stream using a min-heap.

    Records sharing the same trimmed primary key are collapsed through a
    DeduplicateMergeFunction so only the latest version of each key survives.
    """

    def __init__(self, readers: List[RecordReader[KeyValue]], schema: TableSchema):
        self.next_batch_readers = list(readers)
        self.merge_function = DeduplicateMergeFunction()

        # Within a single partition every row shares the partition-key values,
        # so those fields carry no ordering information and are trimmed from
        # the sort key.
        if schema.partition_keys:
            trimmed_primary_keys = [
                pk for pk in schema.primary_keys if pk not in schema.partition_keys
            ]
            if not trimmed_primary_keys:
                raise ValueError(f"Primary key constraint {schema.primary_keys} same with partition fields")
        else:
            trimmed_primary_keys = schema.primary_keys

        fields_by_name = {f.name: f for f in schema.fields}
        key_fields = [fields_by_name[name] for name in trimmed_primary_keys if name in fields_by_name]
        self.key_comparator = builtin_key_comparator(key_fields)
        # Heap of HeapEntry ordered by (key, sequence_number).
        self.min_heap = []
        # Elements popped for the key most recently merged; re-seeded on next().
        self.polled = []

    def read_batch(self) -> Optional[RecordIterator]:
        """Prime the heap with one record per reader and hand back the merge iterator."""
        for reader in self.next_batch_readers:
            seeded = False
            while not seeded:
                batch = reader.read_batch()
                if batch is None:
                    # Reader has nothing at all; release it immediately.
                    reader.close()
                    break
                first = batch.next()
                if first is not None:
                    element = Element(first, batch, reader)
                    heapq.heappush(self.min_heap,
                                   HeapEntry(first.key, element, self.key_comparator))
                    seeded = True
                # Empty batch: loop around and fetch the next one.
        self.next_batch_readers.clear()

        if not self.min_heap:
            return None
        return SortMergeIterator(
            self,
            self.polled,
            self.min_heap,
            self.merge_function,
            self.key_comparator,
        )

    def close(self):
        """Close every underlying reader that has not yet been exhausted."""
        still_open = (
            list(self.next_batch_readers)
            + [entry.element.reader for entry in self.min_heap]
            + [element.reader for element in self.polled]
        )
        for reader in still_open:
            reader.close()
class SortMergeIterator(RecordIterator):
    """Iterator over the merged stream: each next() yields the merge result of one key."""

    def __init__(self, reader, polled: List['Element'], min_heap, merge_function,
                 key_comparator):
        self.reader = reader
        self.polled = polled
        self.min_heap = min_heap
        self.merge_function = merge_function
        self.key_comparator = key_comparator
        self.released = False

    def next(self):
        """Return the merged record for the next key, or None when exhausted.

        Keys whose merge result is suppressed (get_result() -> None) are
        skipped, so the loop continues until a concrete record appears or the
        heap drains.
        """
        while True:
            if not self._advance_to_next_key():
                return None
            merged = self.merge_function.get_result()
            if merged is not None:
                return merged

    def _advance_to_next_key(self):
        # Re-seed the heap with the elements consumed for the previous key;
        # exhausted elements drop out (update() closes their reader).
        for element in self.polled:
            if element.update():
                heapq.heappush(self.min_heap,
                               HeapEntry(element.kv.key, element, self.key_comparator))
        self.polled.clear()

        if not self.min_heap:
            return False

        # Pop every entry that shares the current smallest key and feed the
        # merge function in heap (sequence-number) order.
        self.merge_function.reset()
        smallest_key = self.min_heap[0].key
        while self.min_heap and self.key_comparator(smallest_key, self.min_heap[0].key) == 0:
            popped = heapq.heappop(self.min_heap)
            self.merge_function.add(popped.element.kv)
            self.polled.append(popped.element)
        return True
class DeduplicateMergeFunction:
    """Merge strategy for a unique primary key: the last value added wins."""

    def __init__(self):
        # Most recent KeyValue seen since the last reset(); None when empty.
        self.latest_kv = None

    def reset(self) -> None:
        """Discard any previously accumulated value."""
        self.latest_kv = None

    def add(self, kv: "KeyValue"):
        """Record kv, replacing whatever was added before it."""
        self.latest_kv = kv

    def get_result(self) -> "Optional[KeyValue]":
        """Return the winning (latest) KeyValue, or None if nothing was added."""
        return self.latest_kv
class Element:
    """A (current record, batch iterator, owning reader) triple tracked by the heap.

    `kv` is this source's current head record; update() advances it to the
    next record, transparently crossing batch boundaries.
    """

    def __init__(self, kv: "KeyValue", iterator: "RecordIterator[KeyValue]",
                 reader: "RecordReader[KeyValue]"):
        self.kv = kv
        self.iterator = iterator
        self.reader = reader

    def update(self) -> bool:
        """Advance `kv` to the next record from this source.

        Returns True when `kv` now holds the next record, or False when the
        underlying reader is fully exhausted (the reader is closed before
        returning False).
        """
        next_kv = self.iterator.next()
        if next_kv is not None:
            self.kv = next_kv
            return True
        # Current batch is exhausted: keep fetching batches until one yields a
        # record. A batch may legitimately be empty, so loop instead of
        # assuming the first new batch has data — the previous code read a
        # single batch, and on an empty one set kv=None while still
        # reporting success, corrupting the merge heap. This mirrors the
        # empty-batch skip loop used when priming the heap in read_batch().
        while True:
            self.iterator = self.reader.read_batch()
            if self.iterator is None:
                self.reader.close()
                return False
            next_kv = self.iterator.next()
            if next_kv is not None:
                self.kv = next_kv
                return True
class HeapEntry:
    """Heap node ordering elements by key first, then by sequence number.

    Only __lt__ is defined because heapq needs nothing more: smaller keys pop
    first, and entries with equal keys pop in ascending sequence-number order
    so older records reach the merge function before newer ones.
    """

    def __init__(self, key: "InternalRow", element: "Element", key_comparator):
        self.key = key
        self.element = element
        self.key_comparator = key_comparator

    def __lt__(self, other):
        by_key = self.key_comparator(self.key, other.key)
        if by_key != 0:
            return by_key < 0
        # Equal keys: tie-break on sequence number (older first).
        return self.element.kv.sequence_number < other.element.kv.sequence_number
def builtin_key_comparator(key_schema: "List[DataField]") -> Callable[[Any, Any], int]:
    """Build a field-by-field three-way comparator over InternalRow keys.

    The returned callable yields a negative / zero / positive int. None sorts
    before any non-None value, both for whole keys and for individual fields.
    Fields whose declared type is not orderable (VARIANT) raise ValueError
    when actually compared.
    """
    # Pre-resolve, per key field, whether its base type keyword is orderable,
    # so the type string is not re-examined on every comparison.
    orderable_keywords = {member.value for member in Keyword if member is not Keyword.VARIANT}
    # NOTE(review): taking the token before the first space presumably strips
    # parameters from the declared type string — confirm against the
    # DataField type format.
    field_orderable = [field.type.type.split(' ')[0] in orderable_keywords
                       for field in key_schema]

    def comparator(key1: "InternalRow", key2: "InternalRow") -> int:
        if key1 is None:
            return 0 if key2 is None else -1
        if key2 is None:
            return 1
        for idx, can_compare in enumerate(field_orderable):
            left = key1.get_field(idx)
            right = key2.get_field(idx)
            if left is None:
                if right is None:
                    continue  # both null: this field cannot decide the order
                return -1
            if right is None:
                return 1
            if not can_compare:
                raise ValueError(f"Unsupported {key_schema[idx].type} comparison")
            if left < right:
                return -1
            if left > right:
                return 1
        return 0

    return comparator