################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from collections import defaultdict
from typing import List

import pyarrow as pa
import pyarrow.compute as pc

from pypaimon.schema.data_types import PyarrowFieldParser
from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
from pypaimon.write.commit_message import CommitMessage
from pypaimon.write.file_store_write import FileStoreWrite


class TableWrite:
    """Writes PyArrow or pandas data to a Paimon table, routing rows to per-partition, per-bucket writers."""

    def __init__(self, table, commit_user):
        # Local import: FileStoreTable is only needed for the type annotation below.
        from pypaimon.table.file_store_table import FileStoreTable

        self.table: FileStoreTable = table
        self.table_pyarrow_schema = PyarrowFieldParser.from_paimon_schema(self.table.table_schema.fields)
        self.file_store_write = FileStoreWrite(self.table, commit_user)
        self.row_key_extractor = self.table.create_row_key_extractor()
        self.commit_user = commit_user
    def write_arrow(self, table: pa.Table):
        for batch in table.to_batches():
            self.write_arrow_batch(batch)

    def write_arrow_batch(self, data: pa.RecordBatch):
        self._validate_pyarrow_schema(data.schema)
        partitions, buckets = self.row_key_extractor.extract_partition_bucket_batch(data)

        # Group row indices by (partition, bucket) so each group can be written as one batch.
        partition_bucket_groups = defaultdict(list)
        for i in range(data.num_rows):
            partition_bucket_groups[(tuple(partitions[i]), buckets[i])].append(i)

        for (partition, bucket), row_indices in partition_bucket_groups.items():
            indices_array = pa.array(row_indices, type=pa.int64())
            sub_batch = pc.take(data, indices_array)
            self.file_store_write.write(partition, bucket, sub_batch)
    def write_pandas(self, dataframe):
        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self.table_pyarrow_schema)
        return self.write_arrow_batch(record_batch)

    def with_write_type(self, write_cols: List[str]):
        for col in write_cols:
            if col not in self.table_pyarrow_schema.names:
                raise ValueError(f"Column {col} is not in the table schema.")
        # Writing every column is equivalent to the default full-row write.
        if len(write_cols) == len(self.table_pyarrow_schema.names):
            write_cols = None
        self.file_store_write.write_cols = write_cols
        return self
    def close(self):
        self.file_store_write.close()

    def _validate_pyarrow_schema(self, data_schema: pa.Schema):
        if data_schema != self.table_pyarrow_schema and data_schema.names != self.file_store_write.write_cols:
            raise ValueError(f"Input schema matches neither the table schema nor the write cols. "
                             f"Input schema: {data_schema} "
                             f"Table schema: {self.table_pyarrow_schema} "
                             f"Write cols: {self.file_store_write.write_cols}")

class BatchTableWrite(TableWrite):
    """A TableWrite for batch jobs; prepare_commit may be called at most once."""

    def __init__(self, table, commit_user):
        super().__init__(table, commit_user)
        self.batch_committed = False

    def prepare_commit(self) -> List[CommitMessage]:
        if self.batch_committed:
            raise RuntimeError("BatchTableWrite only supports committing once.")
        self.batch_committed = True
        return self.file_store_write.prepare_commit(BATCH_COMMIT_IDENTIFIER)


class StreamTableWrite(TableWrite):
    """A TableWrite for streaming jobs; each prepare_commit call takes an explicit commit identifier."""

    def prepare_commit(self, commit_identifier) -> List[CommitMessage]:
        return self.file_store_write.prepare_commit(commit_identifier)
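
# ------------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal batch write/commit round trip.
# Everything outside this module -- the catalog lookup and the write-builder
# methods `new_batch_write_builder()`, `new_write()`, `new_commit()`, and
# `commit()` -- is assumed here, not defined in this file.
#
#     import pyarrow as pa
#
#     table = catalog.get_table("default.orders")
#     write_builder = table.new_batch_write_builder()
#     table_write: BatchTableWrite = write_builder.new_write()
#     table_commit = write_builder.new_commit()
#
#     data = pa.table({"order_id": [1, 2, 3], "amount": [9.5, 3.0, 7.25]})
#     table_write.write_arrow(data)
#     commit_messages = table_write.prepare_commit()  # one-shot for batch writes
#     table_commit.commit(commit_messages)
#     table_write.close()
# ------------------------------------------------------------------------------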