blob: 0d7dd568c0fc207afa59aca80b4cdaca26c65e5b [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""End-to-end tests for HASH_FIXED Ray writes.
For append-only HASH_FIXED tables, ``write_paimon`` writes rows to the
correct bucket by default without pre-clustering. HASH_FIXED
primary-key tables fail fast unless the legacy ``map_groups`` mode is
explicitly selected. These tests cover:
* default roundtrip correctness on an append-only HASH_FIXED table.
* default fail-fast behaviour on a HASH_FIXED PK table.
* roundtrip correctness on a partitioned HASH_FIXED PK table.
* explicit ``map_groups`` mode strips the transient bucket column
from the sink-visible schema.
* explicit ``map_groups`` mode can produce one file per
(partition, bucket) on the small test dataset.
* regression: a table whose schema already contains a column named
``__paimon_bucket__`` still works (collision-safe column name).
* non-HASH_FIXED append-only tables pass through unchanged.
* dynamic-bucket primary-key tables fail fast, while postpone-bucket
primary-key tables pass through.
"""
import glob
import os
import shutil
import tempfile
import unittest
import pyarrow as pa
import ray
from pypaimon import CatalogFactory, Schema
class RayShuffleTest(unittest.TestCase):
    """End-to-end Ray write tests across Paimon bucket modes.

    All tests share one temporary warehouse created in ``setUpClass``.
    A local Ray runtime is started only when none is running yet, and
    ``tearDownClass`` shuts it down only if this class started it.
    """

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog_options = {'warehouse': cls.warehouse}
        # Track ownership so tearDownClass never shuts down a Ray runtime
        # that something else in this process (another test module, a
        # fixture) already started.
        cls._started_ray = False
        try:
            catalog = CatalogFactory.create(cls.catalog_options)
            catalog.create_database('default', True)
            if not ray.is_initialized():
                # 4 CPUs gives us enough room to actually fan a multi-block
                # write across multiple workers so the "small-file" claim
                # is observable.
                ray.init(ignore_reinit_error=True, num_cpus=4)
                cls._started_ray = True
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so remove the temporary warehouse here before propagating.
            shutil.rmtree(cls.tempdir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        # Cleanup is deliberately best-effort: a failure here must not
        # mask the actual test results.
        if cls._started_ray:
            try:
                if ray.is_initialized():
                    ray.shutdown()
            except Exception:
                pass
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    # ----- helpers -----

    def _make_table(self, table_name, pa_schema, *, primary_keys=None,
                    partition_keys=None, options=None):
        """Create ``default.<table_name>`` in the shared warehouse.

        Args:
            table_name: Table name inside the ``default`` database.
            pa_schema: pyarrow schema converted via
                ``Schema.from_pyarrow_schema``.
            primary_keys: Optional primary-key column names.
            partition_keys: Optional partition column names.
            options: Optional table options (e.g. ``{'bucket': '4'}``).

        Returns:
            The ``'default.<table_name>'`` identifier string.
        """
        identifier = 'default.{}'.format(table_name)
        schema = Schema.from_pyarrow_schema(
            pa_schema,
            primary_keys=primary_keys,
            partition_keys=partition_keys,
            options=options,
        )
        catalog = CatalogFactory.create(self.catalog_options)
        catalog.create_table(identifier, schema, False)
        return identifier

    def _read_table(self, identifier):
        """Read table data via the direct API (not ``read_paimon``).

        This avoids going through ``RayDatasource._get_read_task``, which
        has a pre-existing strict nullability check (``from_batches``
        with the Paimon schema). That check rejects batches where the
        reader drops ``not null`` (a raw-convertible PK split issue).
        Shuffle tests care about *write* correctness, not the Ray read
        path.

        Returns:
            A pandas DataFrame with all rows from every planned split
            (empty when the reader returns ``None``).
        """
        catalog = CatalogFactory.create(self.catalog_options)
        table = catalog.get_table(identifier)
        rb = table.new_read_builder()
        splits = rb.new_scan().plan().splits()
        arrow = rb.new_read().to_arrow(splits)
        return arrow.to_pandas() if arrow is not None else pa.table({}).to_pandas()

    def _count_data_files(self, table_name):
        """Return the paths of all data files under the table directory.

        Files are matched inside any ``bucket-*`` directory at any depth,
        so both partitioned and unpartitioned layouts are covered.
        """
        root = os.path.join(self.warehouse, 'default.db', table_name)
        files = []
        for pattern in ('*.parquet', '*.orc', '*.avro'):
            files.extend(glob.glob(
                os.path.join(root, '**', 'bucket-*', pattern), recursive=True,
            ))
        return files

    @staticmethod
    def _pk_id_name_schema():
        """Schema with a non-nullable ``id`` key column and a ``name`` column."""
        return pa.schema([
            pa.field('id', pa.int32(), nullable=False),
            ('name', pa.string()),
        ])

    @staticmethod
    def _id_name_rows(pa_schema, n=40):
        """Build ``n`` rows ``(i, 'v<i>')`` that conform to *pa_schema*."""
        return pa.Table.from_pydict(
            {'id': list(range(n)), 'name': [f'v{i}' for i in range(n)]},
            schema=pa_schema,
        )

    # ----- HASH_FIXED writes -----

    def test_append_only_fixed_bucket_roundtrip(self):
        """Default write to an append-only bucket=4 table keeps every row."""
        from pypaimon.ray import write_paimon
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('name', pa.string()),
        ])
        table_name = 'test_append_only_fixed_bucket_roundtrip'
        identifier = self._make_table(
            table_name, pa_schema, options={'bucket': '4'},
        )
        ds = ray.data.from_arrow(self._id_name_rows(pa_schema)).repartition(4)
        write_paimon(ds, identifier, self.catalog_options)
        result = self._read_table(identifier)
        self.assertEqual(len(result), 40)
        self.assertEqual(set(result['id']), set(range(40)))
        # The transient bucket column must not leak into stored data.
        self.assertNotIn('__paimon_bucket__', result.columns)

    def test_primary_key_fixed_bucket_default_fails_fast(self):
        """``write_paimon`` rejects HASH_FIXED PK tables without map_groups."""
        from pypaimon.ray import write_paimon
        pa_schema = self._pk_id_name_schema()
        table_name = 'test_pk_fixed_bucket_default_fails_fast'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id'], options={'bucket': '4'},
        )
        ds = ray.data.from_arrow(self._id_name_rows(pa_schema)).repartition(4)
        with self.assertRaisesRegex(ValueError, "HASH_FIXED primary-key"):
            write_paimon(ds, identifier, self.catalog_options)

    def test_table_write_ray_primary_key_fixed_bucket_default_fails_fast(self):
        """``write_ray`` rejects HASH_FIXED PK tables without map_groups."""
        pa_schema = self._pk_id_name_schema()
        table_name = 'test_table_write_ray_pk_default_fails_fast'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id'], options={'bucket': '4'},
        )
        ds = ray.data.from_arrow(self._id_name_rows(pa_schema)).repartition(4)
        catalog = CatalogFactory.create(self.catalog_options)
        table = catalog.get_table(identifier)
        writer = table.new_batch_write_builder().new_write()
        try:
            with self.assertRaisesRegex(ValueError, "HASH_FIXED primary-key"):
                writer.write_ray(ds)
        finally:
            writer.close()

    def test_primary_key_dynamic_bucket_default_fails_fast(self):
        """``write_paimon`` rejects dynamic-bucket PK tables (no bucket option)."""
        from pypaimon.ray import write_paimon
        pa_schema = self._pk_id_name_schema()
        table_name = 'test_pk_dynamic_bucket_default_fails_fast'
        identifier = self._make_table(
            table_name, pa_schema, primary_keys=['id'],
        )
        ds = ray.data.from_arrow(self._id_name_rows(pa_schema)).repartition(4)
        with self.assertRaisesRegex(ValueError, "HASH_DYNAMIC primary-key"):
            write_paimon(ds, identifier, self.catalog_options)

    def test_table_write_ray_primary_key_dynamic_bucket_default_fails_fast(self):
        """``write_ray`` rejects dynamic-bucket PK tables (no bucket option)."""
        pa_schema = self._pk_id_name_schema()
        table_name = 'test_table_write_ray_pk_dynamic_default_fails_fast'
        identifier = self._make_table(
            table_name, pa_schema, primary_keys=['id'],
        )
        ds = ray.data.from_arrow(self._id_name_rows(pa_schema)).repartition(4)
        catalog = CatalogFactory.create(self.catalog_options)
        table = catalog.get_table(identifier)
        writer = table.new_batch_write_builder().new_write()
        try:
            with self.assertRaisesRegex(ValueError, "HASH_DYNAMIC primary-key"):
                writer.write_ray(ds)
        finally:
            writer.close()

    def test_primary_key_postpone_bucket_roundtrip_to_postpone_files(self):
        """Postpone-bucket (bucket=-2) PK writes land only in ``bucket-postpone``."""
        from pypaimon.ray import write_paimon
        pa_schema = pa.schema([
            pa.field('id', pa.int32(), nullable=False),
            ('dt', pa.string()),
            ('value', pa.int64()),
        ])
        table_name = 'test_pk_postpone_bucket_ray_write'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id', 'dt'], partition_keys=['dt'],
            options={'bucket': '-2'},
        )
        rows = pa.Table.from_pydict({
            'id': list(range(10)),
            'dt': ['2026-01-01'] * 5 + ['2026-01-02'] * 5,
            'value': list(range(10)),
        }, schema=pa_schema)
        write_paimon(
            ray.data.from_arrow(rows).repartition(2),
            identifier,
            self.catalog_options,
        )
        files = self._count_data_files(table_name)
        self.assertGreater(len(files), 0)
        # Compare each file's parent directory name instead of a
        # '/'-delimited substring so the check is separator-agnostic
        # (e.g. also holds with Windows paths).
        bucket_dirs = {os.path.basename(os.path.dirname(p)) for p in files}
        self.assertEqual(bucket_dirs, {'bucket-postpone'})
        # The test expects rows written to bucket-postpone to not yet be
        # visible to a batch read.
        self.assertEqual(len(self._read_table(identifier)), 0)

    def test_partitioned_fixed_bucket_roundtrip(self):
        """Partitioned table — confirms the post-groupby schema does not
        end up with duplicated partition-key or bucket columns."""
        from pypaimon.ray import write_paimon
        pa_schema = pa.schema([
            pa.field('id', pa.int32(), nullable=False),
            ('dt', pa.string()),
            ('value', pa.int64()),
        ])
        table_name = 'test_partitioned_fixed_bucket_roundtrip'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id', 'dt'], partition_keys=['dt'],
            options={'bucket': '4'},
        )
        rows = pa.Table.from_pydict({
            'id': list(range(20)),
            'dt': ['2026-01-01'] * 10 + ['2026-01-02'] * 10,
            'value': list(range(20)),
        }, schema=pa_schema)
        ds = ray.data.from_arrow(rows).repartition(4)
        write_paimon(
            ds,
            identifier,
            self.catalog_options,
            hash_fixed_precluster="map_groups",
        )
        result = self._read_table(identifier)
        self.assertEqual(set(result.columns), {'id', 'dt', 'value'})
        self.assertEqual(len(result), 20)
        self.assertEqual(set(result['dt']), {'2026-01-01', '2026-01-02'})

    def test_table_write_ray_primary_key_fixed_bucket_map_groups_roundtrip(self):
        """``write_ray`` with explicit map_groups roundtrips a HASH_FIXED PK table."""
        pa_schema = self._pk_id_name_schema()
        table_name = 'test_table_write_ray_pk_map_groups'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id'], options={'bucket': '4'},
        )
        ds = ray.data.from_arrow(self._id_name_rows(pa_schema)).repartition(4)
        catalog = CatalogFactory.create(self.catalog_options)
        table = catalog.get_table(identifier)
        writer = table.new_batch_write_builder().new_write()
        try:
            writer.write_ray(ds, hash_fixed_precluster="map_groups")
        finally:
            writer.close()
        result = self._read_table(identifier)
        self.assertEqual(len(result), 40)
        self.assertEqual(set(result['id']), set(range(40)))

    def test_fixed_bucket_writes_one_file_per_bucket(self):
        """With multiple input blocks, explicit map_groups clustering
        collapses per-task files into per-bucket files."""
        from pypaimon.ray import write_paimon
        pa_schema = pa.schema([
            pa.field('id', pa.int32(), nullable=False),
            ('value', pa.int64()),
        ])
        rows = pa.Table.from_pydict(
            {'id': list(range(200)), 'value': list(range(200))},
            schema=pa_schema,
        )
        table_name = 'test_one_file_per_bucket'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id'], options={'bucket': '4'},
        )
        # Materialise 4 input blocks. Without the explicit map_groups
        # mode, each task would emit one file per bucket it touched.
        write_paimon(
            ray.data.from_arrow(rows).repartition(4),
            identifier, self.catalog_options,
            hash_fixed_precluster="map_groups",
        )
        files = self._count_data_files(table_name)
        # 4 buckets × 1 file each.
        self.assertEqual(len(files), 4)

    def test_fixed_bucket_with_colliding_column_name(self):
        """A table that has a column named ``__paimon_bucket__`` must
        still work — the helper picks a collision-free transient
        column name."""
        from pypaimon.ray import write_paimon
        pa_schema = pa.schema([
            pa.field('id', pa.int32(), nullable=False),
            ('__paimon_bucket__', pa.string()),
        ])
        table_name = 'test_fixed_bucket_collide_col'
        identifier = self._make_table(
            table_name, pa_schema,
            primary_keys=['id'], options={'bucket': '2'},
        )
        rows = pa.Table.from_pydict(
            {'id': list(range(10)),
             '__paimon_bucket__': [f'v{i}' for i in range(10)]},
            schema=pa_schema,
        )
        ds = ray.data.from_arrow(rows).repartition(2)
        write_paimon(
            ds,
            identifier,
            self.catalog_options,
            hash_fixed_precluster="map_groups",
        )
        result = self._read_table(identifier)
        self.assertEqual(len(result), 10)
        self.assertEqual(set(result.columns), {'id', '__paimon_bucket__'})

    # ----- non-HASH_FIXED passthrough -----

    def test_non_fixed_bucket_roundtrip(self):
        """BUCKET_UNAWARE tables are written without pre-clustering;
        roundtrip data must still be correct."""
        from pypaimon.ray import read_paimon, write_paimon
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('value', pa.int64()),
        ])
        # bucket=-1 + no primary keys → BUCKET_UNAWARE
        table_name = 'test_non_fixed_bucket_roundtrip'
        identifier = self._make_table(
            table_name, pa_schema, options={'bucket': '-1'},
        )
        rows = pa.Table.from_pydict(
            {'id': list(range(10)), 'value': list(range(10))},
            schema=pa_schema,
        )
        write_paimon(
            ray.data.from_arrow(rows), identifier, self.catalog_options,
        )
        result = read_paimon(identifier, self.catalog_options).to_pandas()
        self.assertEqual(len(result), 10)
        self.assertEqual(set(result['id']), set(range(10)))
# Allow running this module directly with the stdlib unittest runner.
if __name__ == '__main__':
    unittest.main()