blob: 4e02bf32f67a7da1bed1efc2440607b35e8727fa [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import shutil
import tempfile
import unittest
import pyarrow as pa
from pypaimon import CatalogFactory, Schema
from pypaimon.write.commit_callback import CommitCallback, CommitCallbackContext
class RecordingCallback(CommitCallback):
    """Commit callback used by the tests to observe how it is driven.

    Every context handed to ``call`` is kept, in arrival order, in
    ``contexts``; ``closed`` flips to ``True`` once ``close`` has run.
    """

    def __init__(self) -> None:
        # Nothing has been closed or recorded yet.
        self.closed = False
        self.contexts = []

    def call(self, context: CommitCallbackContext) -> None:
        """Remember ``context`` so the test can inspect it afterwards."""
        self.contexts.append(context)

    def close(self) -> None:
        """Mark this callback as closed."""
        self.closed = True
class CommitCallbackTest(unittest.TestCase):
    """Tests for commit callbacks registered via ``add_commit_callback``.

    All tests share one catalog rooted in a temporary warehouse directory that
    is created in ``setUpClass`` and removed in ``tearDownClass``. Each test
    creates its own, uniquely named table in the ``default`` database.
    """

    @classmethod
    def setUpClass(cls):
        """Create the temporary warehouse, the catalog and the shared schema."""
        cls.tempdir = tempfile.mkdtemp()
        try:
            cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
            cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
            cls.catalog.create_database('default', True)
            cls.pa_schema = pa.schema([
                ('id', pa.int32()),
                ('name', pa.string()),
                ('dt', pa.string()),
            ])
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so the
            # temporary directory has to be removed here to avoid leaking it.
            shutil.rmtree(cls.tempdir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary warehouse directory."""
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_table(self, table_name, partition_keys=None, options=None):
        """Create ``default.<table_name>`` from the shared schema and return it.

        Args:
            table_name: Table name inside the ``default`` database.
            partition_keys: Optional list of partition columns; ``None`` is
                treated as no partition keys.
            options: Optional dict of table options; ``None`` is treated as
                no options.
        """
        schema = Schema.from_pyarrow_schema(
            self.pa_schema,
            partition_keys=partition_keys or [],
            options=options or {})
        identifier = f'default.{table_name}'
        self.catalog.create_table(identifier, schema, False)
        return self.catalog.get_table(identifier)

    def _open_write_and_commit(self, table, streaming=False):
        """Return ``(table_write, table_commit)`` for ``table``.

        Both handles are closed automatically when the test finishes, even if
        an assertion fails or the commit raises part-way through the test.

        Args:
            table: Table to write to.
            streaming: Use the stream write builder instead of the batch one.
        """
        if streaming:
            write_builder = table.new_stream_write_builder()
        else:
            write_builder = table.new_batch_write_builder()
        table_write = write_builder.new_write()
        table_commit = write_builder.new_commit()
        # Cleanups run in LIFO order, so the writer is closed before the
        # commit, which is the order the tests use when closing explicitly.
        self.addCleanup(table_commit.close)
        self.addCleanup(table_write.close)
        return table_write, table_commit

    def _rows(self, ids, names, dts):
        """Build an Arrow table with the shared schema from three column lists."""
        return pa.Table.from_pydict({
            'id': ids,
            'name': names,
            'dt': dts,
        }, schema=self.pa_schema)

    def test_callback_invoked_on_commit(self):
        """A single batch commit triggers exactly one callback invocation."""
        table = self._create_table('test_callback_invoked')
        table_write, table_commit = self._open_write_and_commit(table)
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        table_write.write_arrow(self._rows([1, 2], ['a', 'b'], ['p1', 'p1']))
        table_commit.commit(table_write.prepare_commit())

        self.assertEqual(1, len(callback.contexts))
        ctx = callback.contexts[0]
        self.assertEqual(1, ctx.snapshot.id)
        self.assertEqual('APPEND', ctx.snapshot.commit_kind)
        self.assertGreater(len(ctx.commit_entries), 0)

    def test_callback_receives_correct_snapshot_data(self):
        """The snapshot passed to the callback carries the committed row counts."""
        table = self._create_table('test_callback_snapshot_data', partition_keys=['dt'])
        table_write, table_commit = self._open_write_and_commit(table)
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        table_write.write_arrow(
            self._rows([1, 2, 3], ['a', 'b', 'c'], ['p1', 'p1', 'p2']))
        table_commit.commit(table_write.prepare_commit())

        # Fail with a clear assertion rather than an IndexError if the
        # callback was never invoked.
        self.assertGreater(len(callback.contexts), 0)
        ctx = callback.contexts[0]
        self.assertEqual(3, ctx.snapshot.delta_record_count)
        self.assertEqual(3, ctx.snapshot.total_record_count)

    def test_multiple_callbacks(self):
        """Every registered callback is invoked with the same snapshot id."""
        table = self._create_table('test_multi_callbacks')
        table_write, table_commit = self._open_write_and_commit(table)
        cb1 = RecordingCallback()
        cb2 = RecordingCallback()
        table_commit.add_commit_callback(cb1)
        table_commit.add_commit_callback(cb2)

        table_write.write_arrow(self._rows([1], ['a'], ['p1']))
        table_commit.commit(table_write.prepare_commit())

        self.assertEqual(1, len(cb1.contexts))
        self.assertEqual(1, len(cb2.contexts))
        self.assertEqual(cb1.contexts[0].snapshot.id, cb2.contexts[0].snapshot.id)

    def test_callback_close_on_commit_close(self):
        """Closing the table commit closes the registered callback."""
        table = self._create_table('test_callback_close')
        write_builder = table.new_batch_write_builder()
        # The close call is the behaviour under test, so it stays explicit
        # here instead of being registered as a cleanup.
        table_commit = write_builder.new_commit()
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        self.assertFalse(callback.closed)
        table_commit.close()
        self.assertTrue(callback.closed)

    def test_callback_not_invoked_when_no_data(self):
        """Committing without having written anything invokes no callback."""
        table = self._create_table('test_callback_no_data')
        table_write, table_commit = self._open_write_and_commit(table)
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        table_commit.commit(table_write.prepare_commit())

        self.assertEqual(0, len(callback.contexts))

    def test_stream_commit_callback_multiple_rounds(self):
        """Each of three stream commits yields one callback with ids 1, 2, 3."""
        table = self._create_table('test_stream_callback')
        table_write, table_commit = self._open_write_and_commit(table, streaming=True)
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        for i in range(3):
            table_write.write_arrow(self._rows([i], [f'name_{i}'], ['p1']))
            table_commit.commit(table_write.prepare_commit(i), commit_identifier=i)

        self.assertEqual(3, len(callback.contexts))
        for i, ctx in enumerate(callback.contexts):
            self.assertEqual(i + 1, ctx.snapshot.id)

    def test_data_evolution_callback_sees_row_id(self):
        """With row tracking enabled the callback sees assigned row ids."""
        table = self._create_table('test_de_row_id', options={
            'row-tracking.enabled': 'true',
            'data-evolution.enabled': 'true',
        })
        table_write, table_commit = self._open_write_and_commit(table)
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        table_write.write_arrow(
            self._rows([1, 2, 3], ['a', 'b', 'c'], ['p1', 'p2', 'p3']))
        table_commit.commit(table_write.prepare_commit())

        # Fail with a clear assertion rather than an IndexError if the
        # callback was never invoked.
        self.assertGreater(len(callback.contexts), 0)
        ctx = callback.contexts[0]
        self.assertIsNotNone(ctx.snapshot.next_row_id)
        for entry in ctx.commit_entries:
            self.assertIsNotNone(entry.file.first_row_id)
        total_rows = sum(e.file.row_count for e in ctx.commit_entries)
        self.assertEqual(3, total_rows)
        self.assertEqual(total_rows, ctx.snapshot.next_row_id)

    def test_data_evolution_callback_row_id_increments_across_commits(self):
        """Row ids seen by the callback keep growing over successive commits."""
        table = self._create_table('test_de_row_id_incr', options={
            'row-tracking.enabled': 'true',
            'data-evolution.enabled': 'true',
        })
        table_write, table_commit = self._open_write_and_commit(table, streaming=True)
        callback = RecordingCallback()
        table_commit.add_commit_callback(callback)

        for i in range(3):
            table_write.write_arrow(
                self._rows([i * 2, i * 2 + 1], [f'a{i}', f'b{i}'], ['p1', 'p1']))
            table_commit.commit(table_write.prepare_commit(i), commit_identifier=i)

        self.assertEqual(3, len(callback.contexts))
        # Row IDs must be assigned and monotonically increasing across commits
        prev_next_row_id = 0
        for ctx in callback.contexts:
            for entry in ctx.commit_entries:
                self.assertIsNotNone(entry.file.first_row_id)
                self.assertGreaterEqual(entry.file.first_row_id, prev_next_row_id)
            self.assertGreater(ctx.snapshot.next_row_id, prev_next_row_id)
            prev_next_row_id = ctx.snapshot.next_row_id
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()