blob: be50cfa637d135e4af3372e22fecaa247ecc9f98 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
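"""Tests for the chunk-shuffle split generators and TableScan.with_chunk_shuffle.

Algorithmic tests use Mock entries so they don't touch disk. The end-to-end
tests write real append and data-evolution tables and validate that all
workers together cover the data exactly once. The compatibility tests check
that unsupported combinations (primary-key tables, deletion vectors,
with_slice, limit, non-partition predicates, invalid chunk_size) are rejected.
"""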
import os
import shutil
import tempfile
import unittest
from unittest.mock import Mock
import pyarrow as pa
from pypaimon import CatalogFactory, Schema
from pypaimon.globalindex.indexed_split import IndexedSplit
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
from pypaimon.read.scanner.chunk_shuffle_split_generator import (
AppendChunkShuffleSplitGenerator,
DataEvolutionChunkShuffleSplitGenerator,
)
from pypaimon.read.sliced_split import SlicedSplit
from pypaimon.read.split import DataSplit
from pypaimon.utils.range import Range
def _mock_table(table_path='/tmp/_chunk_shuffle_test_path'):
table = Mock()
table.table_path = table_path
table.options = Mock()
return table
def _mock_entry(partition_values, bucket, file_name, row_count, file_size=1024):
entry = Mock()
entry.partition = Mock()
entry.partition.values = partition_values
entry.bucket = bucket
entry.file = Mock()
entry.file.file_name = file_name
entry.file.file_size = file_size
entry.file.row_count = row_count
# Swallow set_file_path so we don't need to mock partition path encoding.
entry.file.set_file_path = Mock()
return entry
def _make_generator(seed, chunk_size, table=None):
    """Return an AppendChunkShuffleSplitGenerator with fixed test sizing.

    Uses a 128 MiB target split size, a 4 MiB open-file cost and no
    deletion files. When *table* is omitted a fresh ``_mock_table()`` is used.
    """
    return AppendChunkShuffleSplitGenerator(
        _mock_table() if table is None else table,
        target_split_size=128 * 1024 * 1024,
        open_file_cost=4 * 1024 * 1024,
        deletion_files_map=None,
        seed=seed,
        chunk_size=chunk_size,
    )
def _make_de_generator(seed, chunk_size, table=None):
    """Return a DataEvolutionChunkShuffleSplitGenerator with fixed test sizing.

    Mirrors ``_make_generator``: 128 MiB target split size, 4 MiB open-file
    cost, no deletion files, and a fresh ``_mock_table()`` when *table* is None.
    """
    return DataEvolutionChunkShuffleSplitGenerator(
        _mock_table() if table is None else table,
        target_split_size=128 * 1024 * 1024,
        open_file_cost=4 * 1024 * 1024,
        deletion_files_map=None,
        seed=seed,
        chunk_size=chunk_size,
    )
def _mock_de_entry(partition_values, bucket, file_name, first_row_id, row_count, file_size=1024):
    """Build a fake manifest entry for the data-evolution generator tests.

    The file is a ``Mock(spec=DataFileMeta)`` carrying ``first_row_id`` and a
    ``row_id_range()`` that returns a real inclusive ``Range``, so code calling
    ``row_id_range`` or ``Range.overlaps`` works against it.
    """
    meta = Mock(spec=DataFileMeta)
    meta.file_name = file_name
    meta.file_size = file_size
    meta.row_count = row_count
    meta.first_row_id = first_row_id

    def _row_id_range(first=first_row_id, count=row_count):
        # Inclusive upper bound: a file of `count` rows ends at first + count - 1.
        return Range(first, first + count - 1)

    meta.row_id_range = _row_id_range
    meta.set_file_path = Mock()

    partition = Mock()
    partition.values = partition_values

    entry = Mock()
    entry.partition = partition
    entry.bucket = bucket
    entry.file = meta
    return entry
def _split_signature(split):
    """Return a hashable, comparable identity for *split*.

    The tuple is (partition values, bucket, file names, extra), where
    *extra* is the sorted per-file slice map for a SlicedSplit, the row
    ranges for an IndexedSplit, and ``()`` for a plain DataSplit.

    Raises:
        AssertionError: for any other split type.
    """
    if isinstance(split, SlicedSplit):
        inner = split.data_split()
        names = tuple(meta.file_name for meta in inner.files)
        slices = tuple(sorted(split.shard_file_idx_map().items()))
        return tuple(inner.partition.values), inner.bucket, names, slices

    if isinstance(split, IndexedSplit):
        inner = split.data_split()
        # File order is normalised here; row ranges keep the split's own order.
        names = tuple(sorted(meta.file_name for meta in inner.files))
        spans = tuple((rng.from_, rng.to) for rng in split.row_ranges())
        return tuple(inner.partition.values), inner.bucket, names, spans

    if isinstance(split, DataSplit):
        names = tuple(meta.file_name for meta in split.files)
        return tuple(split.partition.values), split.bucket, names, ()

    raise AssertionError("unexpected split type: %r" % type(split))
def _split_rows(split):
"""Effective row count this split actually exposes."""
return split.row_count
class ChunkShuffleSplitGeneratorAlgoTest(unittest.TestCase):
    """Mock-based tests for AppendChunkShuffleSplitGenerator.create_splits.

    Entries come from ``_mock_entry``, so nothing touches disk.
    """

    def test_no_entries_returns_empty(self):
        gen = _make_generator(seed=1, chunk_size=100)
        self.assertEqual(gen.create_splits([]), [])

    def test_full_files_no_truncation(self):
        entries = [
            _mock_entry([], 0, 'f1', 100),
            _mock_entry([], 0, 'f2', 100),
            _mock_entry([], 0, 'f3', 100),
        ]
        gen = _make_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        # 3 chunks, each holding exactly one whole file → all DataSplit, no SlicedSplit
        self.assertEqual(len(splits), 3)
        for s in splits:
            self.assertIsInstance(s, DataSplit)
            # assertIsInstance alone would not catch a SlicedSplit if it
            # subclasses DataSplit, so check the "no SlicedSplit" claim explicitly.
            self.assertNotIsInstance(s, SlicedSplit)
            self.assertEqual(s.row_count, 100)

    def test_chunk_truncates_inside_file(self):
        # one file of 250 rows, chunk_size 100 → 3 chunks: 100, 100, 50
        entries = [_mock_entry([], 0, 'f1', 250)]
        gen = _make_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 3)
        # All three chunks slice the same file → all SlicedSplit
        for s in splits:
            self.assertIsInstance(s, SlicedSplit)
        # union of (start, end) intervals must cover [0, 250)
        intervals = sorted(s.shard_file_idx_map()['f1'] for s in splits)
        self.assertEqual(intervals, [(0, 100), (100, 200), (200, 250)])
        total = sum(end - start for start, end in intervals)
        self.assertEqual(total, 250)

    def test_chunk_spans_multiple_files(self):
        # f1=30, f2=30, f3=30 rows, chunk_size=50 →
        #   chunk 1: f1[0, 30) + f2[0, 20)   (50 rows)
        #   chunk 2: f2[20, 30) + f3[0, 30)  (40 rows)
        entries = [
            _mock_entry([], 0, 'f1', 30),
            _mock_entry([], 0, 'f2', 30),
            _mock_entry([], 0, 'f3', 30),
        ]
        gen = _make_generator(seed=1, chunk_size=50)
        splits = gen.create_splits(entries)
        # total 90 rows, chunk_size 50 → 2 chunks (50 + 40)
        self.assertEqual(len(splits), 2)
        total_rows = sum(_split_rows(s) for s in splits)
        self.assertEqual(total_rows, 90)

    def test_chunk_size_larger_than_total(self):
        entries = [
            _mock_entry([], 0, 'f1', 30),
            _mock_entry([], 0, 'f2', 30),
        ]
        gen = _make_generator(seed=1, chunk_size=1000)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 1)
        # No truncation — full files inside one chunk → DataSplit not SlicedSplit
        self.assertIsInstance(splits[0], DataSplit)
        self.assertNotIsInstance(splits[0], SlicedSplit)
        self.assertEqual(_split_rows(splits[0]), 60)

    def test_deterministic_same_seed_same_order(self):
        entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(20)]
        gen1 = _make_generator(seed=42, chunk_size=50)
        gen2 = _make_generator(seed=42, chunk_size=50)
        splits1 = gen1.create_splits(entries)
        splits2 = gen2.create_splits(entries)
        self.assertEqual(
            [_split_signature(s) for s in splits1],
            [_split_signature(s) for s in splits2],
        )

    def test_different_seed_different_order(self):
        entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(50)]
        gen1 = _make_generator(seed=1, chunk_size=100)
        gen2 = _make_generator(seed=2, chunk_size=100)
        sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
        sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
        # Same set of chunks, different order — high probability they differ on 50 items
        self.assertEqual(sorted(sigs1), sorted(sigs2))
        self.assertNotEqual(sigs1, sigs2)

    def test_shuffle_actually_reorders(self):
        # 20 files in scan order f00..f19. After shuffle the file order should not be sorted.
        entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(20)]
        gen = _make_generator(seed=42, chunk_size=100)
        splits = gen.create_splits(entries)
        file_names = [s.files[0].file_name for s in splits]
        self.assertNotEqual(file_names, sorted(file_names))

    def test_shard_round_trip_no_overlap_no_loss(self):
        # 13 files × 100 rows = 1300 rows. 4 workers.
        entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(13)]
        num_workers = 4
        all_sigs = []
        total_rows = 0
        for worker in range(num_workers):
            gen = _make_generator(seed=7, chunk_size=100)
            gen.with_shard(worker, num_workers)
            # Pass a fresh list each time so any in-place reordering of the
            # input by create_splits cannot leak between workers.
            splits = gen.create_splits(list(entries))
            for s in splits:
                all_sigs.append(_split_signature(s))
                total_rows += _split_rows(s)
        self.assertEqual(total_rows, 13 * 100)
        # No duplicate chunks across workers
        self.assertEqual(len(all_sigs), len(set(all_sigs)))
        # All chunks together equal an unsharded run
        unsharded = _make_generator(seed=7, chunk_size=100).create_splits(list(entries))
        self.assertEqual(
            sorted(all_sigs),
            sorted(_split_signature(s) for s in unsharded),
        )

    def test_shard_balanced_distribution(self):
        # 10 chunks across 3 workers → 4, 3, 3 (front-loaded by _compute_shard_range)
        entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(10)]
        counts = []
        for worker in range(3):
            gen = _make_generator(seed=0, chunk_size=100)
            gen.with_shard(worker, 3)
            counts.append(len(gen.create_splits(list(entries))))
        self.assertEqual(sorted(counts, reverse=True), [4, 3, 3])

    def test_chunks_fewer_than_workers(self):
        # 2 chunks, 5 workers → 3 workers get nothing
        entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(2)]
        empties = 0
        non_empties = 0
        for worker in range(5):
            gen = _make_generator(seed=0, chunk_size=100)
            gen.with_shard(worker, 5)
            n = len(gen.create_splits(list(entries)))
            if n == 0:
                empties += 1
            else:
                non_empties += 1
                self.assertEqual(n, 1)
        self.assertEqual(empties, 3)
        self.assertEqual(non_empties, 2)

    def test_multi_partition_no_chunk_crosses_partition(self):
        # Two partitions of 200 rows each. chunk_size=150 does not divide 200,
        # so a generator that chunked across partition boundaries would emit a
        # chunk mixing p1's tail with p2's head — which this test must catch.
        entries = [
            _mock_entry(['p1'], 0, 'f1', 100),
            _mock_entry(['p1'], 0, 'f2', 100),
            _mock_entry(['p2'], 0, 'f3', 100),
            _mock_entry(['p2'], 0, 'f4', 100),
        ]
        # Ground truth: which partition each file really belongs to.
        file_partition = {e.file.file_name: tuple(e.partition.values) for e in entries}
        gen = _make_generator(seed=0, chunk_size=150)
        splits = gen.create_splits(entries)
        self.assertGreater(len(splits), 0)
        for s in splits:
            data_split = s.data_split() if isinstance(s, SlicedSplit) else s
            split_partition = tuple(data_split.partition.values)
            # Every file in the split must come from the split's own partition.
            partitions_in_files = {file_partition[f.file_name] for f in data_split.files}
            self.assertEqual(partitions_in_files, {split_partition})
        self.assertEqual(sum(_split_rows(s) for s in splits), 400)

    def test_null_and_non_null_partitions_sort_safely(self):
        # Mixing null and non-null partition values used to raise
        # ``TypeError: '<' not supported between instances of 'NoneType' and 'str'``
        # before _null_safe_partition_key. Validate planning succeeds and
        # all partitions produce splits.
        entries = [
            _mock_entry(['p1'], 0, 'f1', 100),
            _mock_entry([None], 0, 'f2', 100),
            _mock_entry(['p2'], 0, 'f3', 100),
        ]
        gen = _make_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 3)
        partitions = {_split_signature(s)[0] for s in splits}
        self.assertEqual(partitions, {('p1',), ('p2',), (None,)})

    def test_input_order_does_not_affect_output_when_same_files(self):
        """Manifest read parallelism shouldn't bleed through — sorting is internal."""
        a = _mock_entry([], 0, 'f1', 100)
        b = _mock_entry([], 0, 'f2', 100)
        c = _mock_entry([], 0, 'f3', 100)
        gen1 = _make_generator(seed=99, chunk_size=100)
        gen2 = _make_generator(seed=99, chunk_size=100)
        sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])]
        sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])]
        self.assertEqual(sigs1, sigs2)
class ChunkShuffleEndToEndTest(unittest.TestCase):
    """Real append table → with_chunk_shuffle → multiple workers → union == original."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
        cls.catalog.create_database('default', True)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_append_table(self, name, partition_keys=None):
        """Create ``default.<name>`` with columns (id, value, part).

        Returns:
            (table, pa_schema) so callers can build matching arrow batches.
        """
        pa_schema = pa.schema([
            ('id', pa.int64()),
            ('value', pa.string()),
            ('part', pa.string()),
        ])
        schema = Schema.from_pyarrow_schema(
            pa_schema, partition_keys=partition_keys or [])
        identifier = f'default.{name}'
        self.catalog.create_table(identifier, schema, False)
        return self.catalog.get_table(identifier), pa_schema

    def _write_n_batches(self, table, pa_schema, batches):
        """Write each dict in *batches* as its own commit.

        The writer and committer are closed even if a write or commit
        raises, so a failing test does not leave them open.
        """
        wb = table.new_batch_write_builder()
        for batch in batches:
            tw = wb.new_write()
            tc = wb.new_commit()
            try:
                tw.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema))
                tc.commit(tw.prepare_commit())
            finally:
                tw.close()
                tc.close()

    def test_workers_union_equals_full_table(self):
        table, pa_schema = self._create_append_table('cs_union')
        # 4 commits × 50 rows = 200 rows across several files
        batches = []
        for c in range(4):
            base = c * 50
            batches.append({
                'id': list(range(base, base + 50)),
                'value': [f'v{i}' for i in range(base, base + 50)],
                'part': ['p1'] * 50,
            })
        self._write_n_batches(table, pa_schema, batches)
        read_builder = table.new_read_builder()
        table_read = read_builder.new_read()
        num_workers = 3
        worker_tables = []
        for w in range(num_workers):
            scan = read_builder.new_scan() \
                .with_chunk_shuffle(seed=123, chunk_size=37) \
                .with_shard(w, num_workers)
            splits = scan.plan().splits()
            if splits:
                worker_tables.append(table_read.to_arrow(splits))
        self.assertTrue(worker_tables, "no worker received any split")
        actual = pa.concat_tables(worker_tables).sort_by('id')
        self.assertEqual(actual.num_rows, 200)
        self.assertEqual(actual.column('id').to_pylist(), list(range(200)))
        # Row contents, not just ids, must survive slicing and sharding.
        self.assertEqual(
            actual.column('value').to_pylist(), [f'v{i}' for i in range(200)])

    def test_deterministic_plan_across_calls(self):
        table, pa_schema = self._create_append_table('cs_determinism')
        self._write_n_batches(table, pa_schema, [{
            'id': list(range(100)),
            'value': [f'v{i}' for i in range(100)],
            'part': ['p'] * 100,
        }])

        def plan_files(worker):
            scan = table.new_read_builder().new_scan() \
                .with_chunk_shuffle(seed=42, chunk_size=20) \
                .with_shard(worker, 3)
            return [_split_signature(s) for s in scan.plan().splits()]

        for worker in range(3):
            self.assertEqual(plan_files(worker), plan_files(worker))
        # Guard against a vacuous pass where every worker plans nothing, and
        # check that no chunk is handed to more than one worker.
        all_sigs = [sig for worker in range(3) for sig in plan_files(worker)]
        self.assertGreater(len(all_sigs), 0)
        self.assertEqual(len(all_sigs), len(set(all_sigs)))
class ChunkShuffleCompatibilityTest(unittest.TestCase):
    """Validates the reject-on-incompatible-combination matrix."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
        cls.catalog.create_database('default', True)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _append_table(self, name, options=None, partition_keys=None):
        """Create an empty append table ``default.<name>`` and return it.

        A ``part`` column is added only when *partition_keys* is given, so
        unpartitioned tables keep the minimal (id, value) schema.
        """
        if partition_keys:
            pa_schema = pa.schema([
                ('id', pa.int64()),
                ('value', pa.string()),
                ('part', pa.string()),
            ])
        else:
            pa_schema = pa.schema([('id', pa.int64()), ('value', pa.string())])
        # Normalise to [] like ChunkShuffleEndToEndTest._create_append_table.
        schema = Schema.from_pyarrow_schema(
            pa_schema, partition_keys=partition_keys or [], options=options or {})
        self.catalog.create_table(f'default.{name}', schema, False)
        return self.catalog.get_table(f'default.{name}')

    def _pk_table(self, name):
        """Create an empty single-bucket primary-key table ``default.<name>``."""
        pa_schema = pa.schema([
            pa.field('id', pa.int64(), nullable=False),
            ('value', pa.string()),
        ])
        schema = Schema.from_pyarrow_schema(
            pa_schema, primary_keys=['id'], options={'bucket': '1'})
        self.catalog.create_table(f'default.{name}', schema, False)
        return self.catalog.get_table(f'default.{name}')

    def _partitioned_table_with_data(self, name):
        """Create ``default.<name>`` partitioned by ``part`` and fill it.

        Partition 'p1' gets ids 0..49 and 'p2' gets ids 50..99, one commit
        each.

        Returns:
            (table, pa_schema).
        """
        pa_schema = pa.schema([
            ('id', pa.int64()),
            ('value', pa.string()),
            ('part', pa.string()),
        ])
        schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['part'])
        identifier = f'default.{name}'
        self.catalog.create_table(identifier, schema, False)
        table = self.catalog.get_table(identifier)
        wb = table.new_batch_write_builder()
        for part, ids in [('p1', range(50)), ('p2', range(50, 100))]:
            tw = wb.new_write()
            tc = wb.new_commit()
            try:
                tw.write_arrow(pa.Table.from_pydict(
                    {'id': list(ids),
                     'value': [f'v{i}' for i in ids],
                     'part': [part] * 50},
                    schema=pa_schema,
                ))
                tc.commit(tw.prepare_commit())
            finally:
                tw.close()
                tc.close()
        return table, pa_schema

    def test_pk_table_rejected(self):
        table = self._pk_table('cs_pk')
        scan = table.new_read_builder().new_scan()
        scan.with_chunk_shuffle(seed=1, chunk_size=100)
        with self.assertRaisesRegex(ValueError, "only supports append tables"):
            scan.plan()

    def test_dv_table_rejected(self):
        table = self._append_table('cs_dv', options={'deletion-vectors.enabled': 'true'})
        scan = table.new_read_builder().new_scan()
        scan.with_chunk_shuffle(seed=1, chunk_size=100)
        with self.assertRaisesRegex(ValueError, "deletion vectors"):
            scan.plan()

    def test_with_slice_then_chunk_shuffle_rejected(self):
        table = self._append_table('cs_slice')
        scan = table.new_read_builder().new_scan()
        scan.with_slice(0, 100).with_chunk_shuffle(seed=1, chunk_size=100)
        with self.assertRaisesRegex(ValueError, "with_slice"):
            scan.plan()

    def test_limit_with_chunk_shuffle_rejected(self):
        table = self._append_table('cs_limit')
        scan = table.new_read_builder().with_limit(50).new_scan()
        scan.with_chunk_shuffle(seed=1, chunk_size=100)
        with self.assertRaisesRegex(ValueError, "limit"):
            scan.plan()

    def test_invalid_chunk_size(self):
        table = self._append_table('cs_invalid')
        scan = table.new_read_builder().new_scan()
        with self.assertRaisesRegex(ValueError, "chunk_size"):
            scan.with_chunk_shuffle(seed=1, chunk_size=0)
        with self.assertRaisesRegex(ValueError, "chunk_size"):
            scan.with_chunk_shuffle(seed=1, chunk_size=-5)

    def test_column_predicate_rejected(self):
        # Non-partition predicate would silently shrink effective chunk
        # row counts inside the reader → not allowed.
        table = self._append_table('cs_col_pred', partition_keys=['part'])
        rb = table.new_read_builder()
        col_pred = rb.new_predicate_builder().equal('id', 5)
        rb = rb.with_filter(col_pred)
        scan = rb.new_scan().with_chunk_shuffle(seed=1, chunk_size=10)
        with self.assertRaisesRegex(ValueError, "partition keys"):
            scan.plan()

    def test_partition_predicate_allowed(self):
        # Filter is partition-only → must succeed and plan only the
        # matching partition.
        table, _ = self._partitioned_table_with_data('cs_part_pred')
        rb = table.new_read_builder()
        pred = rb.new_predicate_builder().equal('part', 'p1')
        scan = rb.with_filter(pred).new_scan() \
            .with_chunk_shuffle(seed=1, chunk_size=10)
        splits = scan.plan().splits()
        # Without this the loop below would pass vacuously on an empty plan.
        self.assertGreater(len(splits), 0)
        for split in splits:
            # chunk_size=10 slices the 50-row file, so unwrap SlicedSplit
            # the same way the other tests in this file do.
            data_split = split.data_split() if isinstance(split, SlicedSplit) else split
            self.assertEqual(tuple(data_split.partition.values), ('p1',))
        # Partition 'p1' holds exactly 50 rows; all of them must be planned.
        self.assertEqual(sum(_split_rows(s) for s in splits), 50)
class DataEvolutionChunkShuffleAlgoTest(unittest.TestCase):
    """Mock-based tests for the DE chunk slicer (DataEvolutionChunkShuffleSplitGenerator)."""

    def test_no_entries_returns_empty(self):
        gen = _make_de_generator(seed=1, chunk_size=100)
        self.assertEqual(gen.create_splits([]), [])

    def test_full_aligned_groups_one_per_chunk(self):
        # Three commits of 100 rows each → three aligned groups.
        # chunk_size = 100 → 3 chunks, each holding one group whole.
        entries = [
            _mock_de_entry([], 0, 'g0.parquet', 0, 100),
            _mock_de_entry([], 0, 'g1.parquet', 100, 100),
            _mock_de_entry([], 0, 'g2.parquet', 200, 100),
        ]
        gen = _make_de_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 3)
        for s in splits:
            self.assertIsInstance(s, IndexedSplit)
            self.assertEqual(s.row_count, 100)
            self.assertEqual(len(s.row_ranges()), 1)

    def test_aligned_group_split_across_chunks(self):
        # One 250-row group, chunk_size=100 → 3 chunks (100, 100, 50).
        # All three chunks reference the SAME aligned group's files but
        # each carries a distinct row_range slice.
        entries = [_mock_de_entry([], 0, 'g0.parquet', 1000, 250)]
        gen = _make_de_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 3)
        # Union of the three chunks' row_ranges must cover the whole group [1000, 1249].
        ranges = []
        for s in splits:
            self.assertIsInstance(s, IndexedSplit)
            ranges.extend((r.from_, r.to) for r in s.row_ranges())
        ranges.sort()
        self.assertEqual(ranges, [(1000, 1099), (1100, 1199), (1200, 1249)])
        total = sum(r[1] - r[0] + 1 for r in ranges)
        self.assertEqual(total, 250)

    def test_chunk_pulls_in_blob_siblings(self):
        # One aligned group with a main parquet and a blob sibling sharing the
        # row_id range. A single chunk must include BOTH files so the reader
        # can union the columns.
        entries = [
            _mock_de_entry([], 0, 'g0.parquet', 0, 100),
            _mock_de_entry([], 0, 'g0.blob', 0, 100),  # .blob ext → is_blob_file
        ]
        gen = _make_de_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 1)
        files = sorted(f.file_name for f in splits[0].files)
        self.assertEqual(files, ['g0.blob', 'g0.parquet'])

    def test_blob_propagates_when_group_split(self):
        # Same scenario but chunk_size halves the group → the blob sibling
        # must appear in BOTH chunk splits.
        entries = [
            _mock_de_entry([], 0, 'g0.parquet', 0, 100),
            _mock_de_entry([], 0, 'g0.blob', 0, 100),
        ]
        gen = _make_de_generator(seed=1, chunk_size=50)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 2)
        for s in splits:
            files = sorted(f.file_name for f in s.files)
            self.assertEqual(files, ['g0.blob', 'g0.parquet'])

    def test_deterministic_same_seed(self):
        entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(20)]
        gen1 = _make_de_generator(seed=42, chunk_size=100)
        gen2 = _make_de_generator(seed=42, chunk_size=100)
        sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
        sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
        self.assertEqual(sigs1, sigs2)

    def test_different_seed_reorders(self):
        entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(50)]
        gen1 = _make_de_generator(seed=1, chunk_size=100)
        gen2 = _make_de_generator(seed=2, chunk_size=100)
        sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
        sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
        self.assertEqual(sorted(sigs1), sorted(sigs2))
        self.assertNotEqual(sigs1, sigs2)

    def test_input_order_does_not_affect_output(self):
        a = _mock_de_entry([], 0, 'g0.parquet', 0, 100)
        b = _mock_de_entry([], 0, 'g1.parquet', 100, 100)
        c = _mock_de_entry([], 0, 'g2.parquet', 200, 100)
        gen1 = _make_de_generator(seed=99, chunk_size=100)
        gen2 = _make_de_generator(seed=99, chunk_size=100)
        sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])]
        sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])]
        self.assertEqual(sigs1, sigs2)

    def test_shard_round_trip_no_overlap_no_loss(self):
        # 13 aligned groups × 100 rows = 1300 rows. Shard across 4 workers.
        entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(13)]
        num_workers = 4
        unsharded = _make_de_generator(seed=7, chunk_size=100).create_splits(list(entries))
        unsharded_sigs = sorted(_split_signature(s) for s in unsharded)
        sharded_sigs = []
        total_rows = 0
        for w in range(num_workers):
            gen = _make_de_generator(seed=7, chunk_size=100)
            gen.with_shard(w, num_workers)
            for s in gen.create_splits(list(entries)):
                sharded_sigs.append(_split_signature(s))
                total_rows += s.row_count
        self.assertEqual(total_rows, 13 * 100)
        # No duplicate splits across workers
        self.assertEqual(len(sharded_sigs), len(set(sharded_sigs)))
        self.assertEqual(sorted(sharded_sigs), unsharded_sigs)

    def test_multi_partition_no_chunk_crosses_partition(self):
        # Two partitions of 200 rows each. chunk_size=150 does not divide 200,
        # so a slicer that chunked across partition boundaries would emit a
        # chunk covering rows from both p1 and p2 — which this test must catch.
        entries = [
            _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100),
            _mock_de_entry(['p1'], 0, 'g1.parquet', 100, 100),
            _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100),
            _mock_de_entry(['p2'], 0, 'g3.parquet', 300, 100),
        ]
        # Ground truth: which partition each file really belongs to.
        file_partition = {e.file.file_name: tuple(e.partition.values) for e in entries}
        gen = _make_de_generator(seed=0, chunk_size=150)
        splits = gen.create_splits(entries)
        self.assertGreater(len(splits), 0)
        for s in splits:
            data_split = s.data_split() if isinstance(s, IndexedSplit) else s
            split_partition = tuple(data_split.partition.values)
            # Every file in the split must come from the split's own partition.
            partitions_in_files = {file_partition[f.file_name] for f in data_split.files}
            self.assertEqual(partitions_in_files, {split_partition})
        self.assertEqual(sum(s.row_count for s in splits), 400)

    def test_null_and_non_null_partitions_sort_safely(self):
        # Same null-vs-non-null sort guard, exercised on the DE path.
        entries = [
            _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100),
            _mock_de_entry([None], 0, 'g1.parquet', 100, 100),
            _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100),
        ]
        gen = _make_de_generator(seed=1, chunk_size=100)
        splits = gen.create_splits(entries)
        self.assertEqual(len(splits), 3)
        partitions = {_split_signature(s)[0] for s in splits}
        self.assertEqual(partitions, {('p1',), ('p2',), (None,)})
class DataEvolutionChunkShuffleEndToEndTest(unittest.TestCase):
    """Real DE table → with_chunk_shuffle → multi-worker → union == full table."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
        cls.catalog.create_database('default', True)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_de_table(self, name):
        """Create ``default.<name>`` with row tracking and data evolution on.

        ``blob.target-file-size`` is set to 1 byte so each commit's
        ``payload`` blob column is written across multiple blob files.

        Returns:
            (table, pa_schema).
        """
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('value', pa.string()),
            ('payload', pa.large_binary()),
        ])
        schema = Schema.from_pyarrow_schema(
            pa_schema,
            options={
                'row-tracking.enabled': 'true',
                'data-evolution.enabled': 'true',
                'blob.target-file-size': '1 b',
            },
        )
        identifier = f'default.{name}'
        self.catalog.create_table(identifier, schema, False)
        return self.catalog.get_table(identifier), pa_schema

    @staticmethod
    def _payloads(ids):
        """Deterministic per-id blob payloads, e.g. b'payload-007' for id 7."""
        return [f'payload-{i:03d}'.encode('utf-8') for i in ids]

    def _commit_full_rows(self, table, pa_schema, ids):
        """Write and commit one batch containing every column for *ids*.

        The writer and committer are closed even if the write or commit
        raises.

        Returns:
            The commit messages produced by ``prepare_commit``.
        """
        wb = table.new_batch_write_builder()
        tw = wb.new_write()
        tc = wb.new_commit()
        try:
            tw.write_arrow(pa.Table.from_pydict(
                {
                    'id': ids,
                    'value': [f'v{i}' for i in ids],
                    'payload': self._payloads(ids),
                },
                schema=pa_schema))
            commit_messages = tw.prepare_commit()
            tc.commit(commit_messages)
        finally:
            tw.close()
            tc.close()
        return commit_messages

    def _assert_commit_has_main_and_multiple_blob_files(self, commit_messages):
        """Check a commit produced >=1 non-blob file and >1 blob file."""
        all_files = [f for msg in commit_messages for f in msg.new_files]
        main_files = [f for f in all_files if not DataFileMeta.is_blob_file(f.file_name)]
        blob_files = [f for f in all_files if DataFileMeta.is_blob_file(f.file_name)]
        self.assertGreaterEqual(len(main_files), 1)
        self.assertGreater(
            len(blob_files), 1,
            "DE chunk-shuffle tests should exercise one row-id group with multiple blob files",
        )

    def _assert_splits_include_blob_files(self, splits):
        """Check *splits* is non-empty and every split carries >=1 blob file."""
        self.assertGreater(len(splits), 0)
        for split in splits:
            data_split = split.data_split() if isinstance(split, IndexedSplit) else split
            blob_files = [
                f for f in data_split.files
                if DataFileMeta.is_blob_file(f.file_name)
            ]
            self.assertGreater(
                len(blob_files), 0,
                "Each DE chunk should keep blob sidecar files with its aligned row-id group",
            )

    def test_workers_union_equals_full_table(self):
        table, pa_schema = self._create_de_table('cs_de_union')
        # 4 commits → 4 aligned groups. Each group has one normal file and
        # multiple blob sidecar files because blob.target-file-size is 1 byte.
        for c in range(4):
            base = c * 50
            commit_messages = self._commit_full_rows(
                table, pa_schema, list(range(base, base + 50)))
            self._assert_commit_has_main_and_multiple_blob_files(commit_messages)
        read_builder = table.new_read_builder()
        table_read = read_builder.new_read()
        num_workers = 3
        worker_tables = []
        for w in range(num_workers):
            scan = read_builder.new_scan() \
                .with_chunk_shuffle(seed=123, chunk_size=37) \
                .with_shard(w, num_workers)
            splits = scan.plan().splits()
            if splits:
                self._assert_splits_include_blob_files(splits)
                worker_tables.append(table_read.to_arrow(splits))
        # pa.concat_tables([]) would fail with an unrelated error; report the
        # real problem instead.
        self.assertTrue(worker_tables, "no worker received any split")
        actual = pa.concat_tables(worker_tables).sort_by('id')
        self.assertEqual(actual.num_rows, 200)
        self.assertEqual(actual.column('id').to_pylist(), list(range(200)))
        self.assertEqual(actual.column('payload').to_pylist(), self._payloads(range(200)))

    def test_deterministic_plan_across_calls(self):
        table, pa_schema = self._create_de_table('cs_de_determinism')
        for c in range(3):
            base = c * 40
            commit_messages = self._commit_full_rows(
                table, pa_schema, list(range(base, base + 40)))
            self._assert_commit_has_main_and_multiple_blob_files(commit_messages)

        def plan_sigs(worker):
            scan = table.new_read_builder().new_scan() \
                .with_chunk_shuffle(seed=42, chunk_size=15) \
                .with_shard(worker, 4)
            splits = scan.plan().splits()
            if splits:
                self._assert_splits_include_blob_files(splits)
            return [_split_signature(s) for s in splits]

        for worker in range(4):
            self.assertEqual(plan_sigs(worker), plan_sigs(worker))
        # Guard against a vacuous pass where every worker plans nothing, and
        # check that no chunk is handed to more than one worker.
        all_sigs = [sig for worker in range(4) for sig in plan_sigs(worker)]
        self.assertGreater(len(all_sigs), 0)
        self.assertEqual(len(all_sigs), len(set(all_sigs)))
if __name__ == '__main__':
    # Allow running this module directly (python <file>.py) in addition to
    # discovery by a test runner.
    unittest.main()