blob: 0b86b59138f139c5214de6d407e6dc0ae86a4ef3 [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Unit tests for ``KeyValueDataWriter`` buffer behaviour.
Covers the fold algorithm (`_merge_pending_by_pk`), the flush lifecycle
(`_flush_all` empties the buffer + clears pending_data), and the
roll-write helper (`_roll_write` splits oversized buffers across
multiple files). Drives a thin harness that bypasses
``DataWriter.__init__`` so tests can exercise these paths without
spinning up the real catalog/write stack.
"""
import unittest
from unittest.mock import Mock
import pyarrow as pa
from pypaimon.read.reader.deduplicate_merge_function import \
DeduplicateMergeFunction
from pypaimon.read.reader.partial_update_merge_function import \
PartialUpdateMergeFunction
from pypaimon.write.writer.key_value_data_writer import KeyValueDataWriter
# Column layout produced by ``KeyValueDataWriter._add_system_fields``:
# ``[_KEY_id, _SEQUENCE_NUMBER, _VALUE_KIND, id, a, b]``. The primary
# key shows up twice (``_KEY_id`` and ``id``) because the value side of
# Paimon's row tuple carries every original column, key included.
_SCHEMA = pa.schema([
    pa.field(name, dtype, nullable=nullable)
    for name, dtype, nullable in (
        # System columns and the key are declared non-nullable ...
        ('_KEY_id', pa.int64(), False),
        ('_SEQUENCE_NUMBER', pa.int64(), False),
        ('_VALUE_KIND', pa.int8(), False),
        ('id', pa.int64(), False),
        # ... while the two payload columns may hold nulls.
        ('a', pa.string(), True),
        ('b', pa.string(), True),
    )
])
def _row(pk, seq, a, b):
return {
'_KEY_id': pk,
'_SEQUENCE_NUMBER': seq,
'_VALUE_KIND': 0,
'id': pk,
'a': a,
'b': b,
}
class _Harness(KeyValueDataWriter):
    """Minimal ``KeyValueDataWriter`` stand-in for buffer tests.

    The constructor deliberately never calls the parent ``__init__``; it
    only seeds the attributes the tests rely on when driving
    ``_merge_pending_by_pk``, ``_flush_all`` and ``_roll_write``. The
    file-writing hook is replaced by an in-memory recorder so the
    roll-write path can run without touching the filesystem.
    """

    def __init__(self, merge_function, target_file_size: int = 10 ** 12):
        # Writer state: empty buffer, nothing committed, and a
        # single-column primary key named ``id``.
        self.pending_data = None
        self.committed_files = []
        self.trimmed_primary_keys = ['id']
        # The default target is huge so size-based rolling stays out of
        # the way unless a test passes a smaller value explicitly.
        self.target_file_size = target_file_size
        self._merge_function = merge_function
        # Every chunk handed to ``_write_data_to_file`` lands here, in
        # call order.
        self.written_chunks = []

    def _write_data_to_file(self, data):
        # In-memory replacement for the on-disk write: keep the chunk so
        # tests can inspect exactly what would have been written. Only
        # ``written_chunks`` is updated here.
        self.written_chunks.append(data)
class WriteMergeBufferTest(unittest.TestCase):
    """Buffer-level tests for ``KeyValueDataWriter`` driven via ``_Harness``."""

    # -- shared helpers ---------------------------------------------------

    class _DropAll:
        """Merge-function stub whose ``get_result`` is always ``None``."""

        def reset(self):
            pass

        def add(self, _):
            pass

        def get_result(self):
            return None

    @staticmethod
    def _table(rows):
        # Build a table laid out like ``_SCHEMA`` from a list of row dicts.
        return pa.Table.from_pylist(rows, schema=_SCHEMA)

    @staticmethod
    def _rows_by_id(table):
        # Row dicts of ``table`` ordered by the ``id`` column, so
        # comparisons do not depend on output order.
        return sorted(table.to_pylist(), key=lambda row: row['id'])

    def _fold(self, merge_function, rows):
        # Run the fold over ``rows`` using a fresh harness.
        harness = _Harness(merge_function)
        return harness._merge_pending_by_pk(self._table(rows))

    # -- deduplicate ------------------------------------------------------

    def test_dedupe_collapses_same_pk_run_to_latest(self):
        merged = self._fold(DeduplicateMergeFunction(), [
            _row(1, 1, 'old', None),
            _row(1, 2, 'new', None),
        ])
        self.assertEqual(merged.num_rows, 1)
        self.assertEqual(merged.to_pylist(), [_row(1, 2, 'new', None)])

    def test_dedupe_keeps_disjoint_keys(self):
        merged = self._fold(DeduplicateMergeFunction(), [
            _row(1, 1, 'A', None),
            _row(2, 2, 'B', None),
            _row(3, 3, 'C', None),
        ])
        self.assertEqual(merged.num_rows, 3)
        expected = [
            _row(1, 1, 'A', None),
            _row(2, 2, 'B', None),
            _row(3, 3, 'C', None),
        ]
        self.assertEqual(self._rows_by_id(merged), expected)

    # -- partial-update ---------------------------------------------------

    def _partial_update(self):
        # Three value-side columns (id, a, b). The PK column ``id`` is
        # repeated on the value side so that last-non-null semantics
        # apply uniformly to every original user column.
        return PartialUpdateMergeFunction(
            key_arity=1, value_arity=3, nullables=[True] * 3)

    def test_partial_update_merges_non_null_per_field(self):
        merged = self._fold(self._partial_update(), [
            _row(1, 1, 'A', None),
            _row(1, 2, None, 'B'),
        ])
        self.assertEqual(merged.num_rows, 1)
        self.assertEqual(merged.to_pylist(), [_row(1, 2, 'A', 'B')])

    def test_partial_update_three_writes_compose(self):
        merged = self._fold(self._partial_update(), [
            _row(1, 1, 'A', None),
            _row(1, 2, None, 'B'),
            _row(1, 3, None, None),
        ])
        self.assertEqual(merged.to_pylist(), [_row(1, 3, 'A', 'B')])

    def test_partial_update_later_null_does_not_clobber_earlier_value(self):
        merged = self._fold(self._partial_update(), [
            _row(1, 1, 'KEEP', 'B'),
            _row(1, 2, None, None),
        ])
        self.assertEqual(merged.to_pylist(), [_row(1, 2, 'KEEP', 'B')])

    # -- edge cases -------------------------------------------------------

    def test_empty_buffer_returns_empty(self):
        merged = self._fold(DeduplicateMergeFunction(), [])
        self.assertEqual(merged.num_rows, 0)

    def test_single_row_buffer_skips_merge(self):
        # A lone row cannot have duplicates, so the fold is expected to
        # leave the merge function untouched (and skip the to_pylist
        # round-trip). A Mock lets us verify no hook was ever called.
        mock_mf = Mock()
        merged = self._fold(mock_mf, [_row(1, 1, 'X', None)])
        self.assertEqual(merged.num_rows, 1)
        for hook in (mock_mf.reset, mock_mf.add, mock_mf.get_result):
            hook.assert_not_called()

    def test_get_result_none_drops_pk_run(self):
        # Future-proofing: by contract a ``None`` from ``get_result``
        # means the whole PK group is dropped.
        merged = self._fold(self._DropAll(), [
            _row(1, 1, 'A', None),
            _row(1, 2, 'B', None),
        ])
        self.assertEqual(merged.num_rows, 0)

    # -- KeyValue pooling -------------------------------------------------

    def test_keyvalue_pool_does_not_alias_results_across_runs(self):
        # One KeyValue is reused for the whole fold. Should
        # PartialUpdateMergeFunction stop snapshotting in get_result,
        # the first run's result would be overwritten once the second
        # run's data is loaded into the pooled instance. Two distinct
        # PK runs whose results must both survive guard against that
        # regression.
        merged = self._fold(self._partial_update(), [
            _row(1, 1, 'A1', None),
            _row(1, 2, None, 'B1'),
            _row(2, 3, 'A2', None),
            _row(2, 4, None, 'B2'),
        ])
        self.assertEqual(
            self._rows_by_id(merged),
            [_row(1, 2, 'A1', 'B1'), _row(2, 4, 'A2', 'B2')],
        )

    # -- _process_data (no longer sorts) ----------------------------------

    def test_process_data_adds_system_fields_without_sorting(self):
        # Deferred-sort design: ``_process_data`` must leave the batch
        # order alone; the single global sort runs inside ``_flush_all``
        # over the concatenated buffer.
        writer = _Harness(DeduplicateMergeFunction())
        writer.sequence_generator = _StubSeqGen()
        user_schema = pa.schema([
            pa.field('id', pa.int64(), nullable=False),
            pa.field('a', pa.string()),
            pa.field('b', pa.string()),
        ])
        # PKs deliberately not in ascending order: a lingering sort in
        # _process_data would hand them back ascending.
        user_rows = [
            {'id': pk, 'a': label, 'b': None}
            for pk, label in ((3, 'C'), (1, 'A'), (2, 'B'))
        ]
        batch = pa.RecordBatch.from_pylist(user_rows, schema=user_schema)
        out = writer._process_data(batch)
        ids_in_output_order = [record['id'] for record in out.to_pylist()]
        self.assertEqual(ids_in_output_order, [3, 1, 2])

    # -- _flush_all -------------------------------------------------------

    def test_flush_all_sorts_folds_and_writes_one_file(self):
        # Duplicate PKs in arbitrary order: sorting ahead of the fold is
        # _flush_all's job, so unsorted input is the right stress case.
        writer = _Harness(DeduplicateMergeFunction())
        writer.pending_data = self._table([
            _row(2, 5, 'B2-new', None),
            _row(1, 2, 'A1-mid', None),
            _row(1, 1, 'A1-old', None),
            _row(2, 4, 'B2-old', None),
            _row(1, 3, 'A1-new', None),
        ])
        writer._flush_all()
        # The buffer must be cleared afterwards.
        self.assertIsNone(writer.pending_data)
        # Far below the target size, so exactly one file is expected.
        self.assertEqual(len(writer.written_chunks), 1)
        # Dedup leaves one row per PK: the one with the highest seq.
        self.assertEqual(self._rows_by_id(writer.written_chunks[0]), [
            _row(1, 3, 'A1-new', None),
            _row(2, 5, 'B2-new', None),
        ])

    def test_flush_all_on_empty_buffer_is_noop(self):
        writer = _Harness(DeduplicateMergeFunction())
        writer.pending_data = None
        writer._flush_all()
        self.assertIsNone(writer.pending_data)
        self.assertEqual(writer.written_chunks, [])

    def test_flush_all_clears_buffer_even_when_fold_drops_everything(self):
        # With a merge function that yields None for every group,
        # ``_flush_all`` must still reset ``pending_data`` so the next
        # write starts from a clean slate.
        writer = _Harness(self._DropAll())
        writer.pending_data = self._table([
            _row(1, 1, 'A', None),
            _row(1, 2, 'B', None),
        ])
        writer._flush_all()
        self.assertIsNone(writer.pending_data)
        self.assertEqual(writer.written_chunks, [])

    # -- _roll_write ------------------------------------------------------

    def test_roll_write_single_chunk_when_under_target(self):
        writer = _Harness(DeduplicateMergeFunction(),
                          target_file_size=10 ** 9)
        writer._roll_write(self._table([
            _row(1, 1, 'A', None),
            _row(2, 2, 'B', None),
        ]))
        chunks = writer.written_chunks
        self.assertEqual(len(chunks), 1)
        self.assertEqual(chunks[0].num_rows, 2)

    def test_roll_write_splits_oversized_buffer_into_multiple_files(self):
        # 400 rows with long strings, so nbytes grows predictably with
        # the row count and comfortably exceeds the target chosen below.
        data = self._table(
            [_row(i, i, 'x' * 64, 'y' * 64) for i in range(1, 401)])
        # A quarter of the buffer size: the rows cannot fit in one file,
        # so at least two files are expected.
        target = data.nbytes // 4
        writer = _Harness(DeduplicateMergeFunction(),
                          target_file_size=target)
        writer._roll_write(data)
        chunks = writer.written_chunks
        self.assertGreaterEqual(len(chunks), 2)
        self.assertEqual(
            sum(chunk.num_rows for chunk in chunks), data.num_rows)
        # Every chunk, except possibly the last, should respect the target.
        for chunk in chunks[:-1]:
            self.assertLessEqual(chunk.nbytes, target)
class _StubSeqGen:
"""Stand-in for ``SequenceGenerator`` so the harness can call
``_process_data`` without going through the real ``DataWriter.__init__``.
"""
def __init__(self):
self._n = 0
def next(self) -> int:
self._n += 1
return self._n
# Run the tests with unittest's runner when this module is executed
# directly as a script.
if __name__ == '__main__':
    unittest.main()