################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from collections import defaultdict
from typing import List, Optional, Tuple
import pyarrow as pa
from pypaimon.common.memory_size import MemorySize
from pypaimon.globalindex import Range
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
from pypaimon.read.split import DataSplit
from pypaimon.write.commit_message import CommitMessage
from pypaimon.write.table_update_by_row_id import TableUpdateByRowId
from pypaimon.write.writer.data_writer import DataWriter
from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter


class TableUpdate:
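    """Update API for a FileStoreTable.

    Updates can be applied either directly by row id (update_by_arrow_with_row_id)
    or through sharded scan-and-rewrite updaters (new_shard_updator).
    """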
def __init__(self, table, commit_user):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
self.commit_user = commit_user
self.update_cols = None
self.projection = None
def with_update_type(self, update_cols: List[str]):
for col in update_cols:
if col not in self.table.field_names:
raise ValueError(f"Column {col} is not in table schema.")
if len(update_cols) == len(self.table.field_names):
update_cols = None
self.update_cols = update_cols
return self
    def with_read_projection(self, projection: List[str]):
        self.projection = projection
        return self
def new_shard_updator(self, shard_num: int, total_shard_count: int):
"""Create a shard updater for scan+rewrite style updates.
Args:
shard_num: Index of this shard/subtask.
total_shard_count: Total number of shards/subtasks.
"""
return ShardTableUpdator(
self.table,
self.projection,
self.update_cols,
self.commit_user,
shard_num,
total_shard_count,
)
def update_by_arrow_with_row_id(self, table: pa.Table) -> List[CommitMessage]:
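        """Update the configured columns of the rows identified by row id in the given arrow table."""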
update_by_row_id = TableUpdateByRowId(self.table, self.commit_user)
update_by_row_id.update_columns(table, self.update_cols)
return update_by_row_id.commit_messages


class ShardTableUpdator:
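    """Rewrites one shard of a table.

    The shard's splits are planned once; the caller reads them via arrow_reader(),
    writes back exactly the same number of rows via update_by_arrow_batch(), and
    collects the resulting files with prepare_commit(). Written files are aligned
    to the row-id ranges of the original data files.
    """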
def __init__(
self,
table,
projection: Optional[List[str]],
        write_cols: Optional[List[str]],
commit_user,
shard_num: int,
total_shard_count: int,
):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
self.projection = projection
self.write_cols = write_cols
self.commit_user = commit_user
self.total_shard_count = total_shard_count
self.shard_num = shard_num
self.write_pos = 0
self.writer: Optional[SingleWriter] = None
self.dict = defaultdict(list)
scanner = self.table.new_read_builder().new_scan().with_shard(shard_num, total_shard_count)
self.splits = scanner.plan().splits()
        self.row_ranges: List[Tuple[Tuple, Range]] = []
for split in self.splits:
if not isinstance(split, DataSplit):
raise ValueError(f"Split {split} is not DataSplit.")
files = split.files
ranges = self.compute_from_files(files)
for row_range in ranges:
self.row_ranges.append((tuple(split.partition.values), row_range))
@staticmethod
def compute_from_files(files: List[DataFileMeta]) -> List[Range]:
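        """Build inclusive row-id ranges from the data files and merge overlapping ranges."""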
ranges = []
for file in files:
ranges.append(Range(file.first_row_id, file.first_row_id + file.row_count - 1))
return Range.sort_and_merge_overlap(ranges, True, False)
    def arrow_reader(self) -> pa.ipc.RecordBatchReader:
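        """Create a record batch reader over this shard's splits, honoring the read projection."""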
read_builder = self.table.new_read_builder()
read_builder.with_projection(self.projection)
return read_builder.new_read().to_arrow_batch_reader(self.splits)
def prepare_commit(self) -> List[CommitMessage]:
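        """Collect one commit message per partition for all files written so far."""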
commit_messages = []
for (partition, files) in self.dict.items():
commit_messages.append(CommitMessage(partition, 0, files))
return commit_messages
def update_by_arrow_batch(self, data: pa.RecordBatch):
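        """Write a batch of updated rows.

        Batches must arrive in the same order and with the same total row count as
        they were read from this shard, so that each rewritten file covers the
        row-id range of an original data file.
        """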
self._init_writer()
capacity = self.writer.capacity()
if capacity <= 0:
raise RuntimeError("Writer has no remaining capacity.")
# Split the batch across writers.
        if capacity >= data.num_rows:
            first, rest = data, None
        else:
            first, rest = data.slice(0, capacity), data.slice(capacity)
self.writer.write(first)
if self.writer.capacity() == 0:
self.dict[self.writer.partition()].append(self.writer.end())
self.writer = None
if rest is not None:
if self.writer is not None:
raise RuntimeError("Should not get here, rest and current writer exist in the same time.")
self.update_by_arrow_batch(rest)
def _init_writer(self):
if self.writer is None:
if self.write_pos >= len(self.row_ranges):
raise RuntimeError(
"No more row ranges to write. "
"Ensure you write exactly the same number of rows as read from this shard."
)
item = self.row_ranges[self.write_pos]
self.write_pos += 1
partition = item[0]
row_range = item[1]
writer = AppendOnlyDataWriter(self.table, partition, 0, 0, self.table.options, self.write_cols)
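            # Use an effectively unlimited target file size so the writer never rolls
            # files, ensuring each row range produces exactly one data file.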
writer.target_file_size = MemorySize.of_mebi_bytes(999999999).get_bytes()
self.writer = SingleWriter(writer, partition, row_range.from_, row_range.to - row_range.from_ + 1)


class SingleWriter:
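    """Wraps a DataWriter that must produce exactly one file covering a fixed row-id range."""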
def __init__(self, writer: DataWriter, partition, first_row_id: int, row_count: int):
self.writer: DataWriter = writer
self._partition = partition
self.first_row_id = first_row_id
self.row_count = row_count
self.written_records_count = 0
def capacity(self) -> int:
return self.row_count - self.written_records_count
def write(self, data: pa.RecordBatch):
if data.num_rows > self.capacity():
raise Exception("Data num size exceeds capacity.")
self.written_records_count += data.num_rows
self.writer.write(data)
def partition(self) -> Tuple:
return self._partition
def end(self) -> DataFileMeta:
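        """Finish writing and return the single produced file with its first row id assigned."""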
if self.capacity() != 0:
raise Exception("There still capacity left in the writer.")
files = self.writer.prepare_commit()
if len(files) != 1:
raise Exception("Should have one file.")
file = files[0]
if file.row_count != self.row_count:
raise Exception("File row count mismatch.")
file = file.assign_first_row_id(self.first_row_id)
return file