################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
import shutil
import tempfile
import unittest

import pyarrow as pa
import ray

from pypaimon import CatalogFactory, Schema


class RayIntegrationTest(unittest.TestCase):
"""Tests for the top-level read_paimon() / write_paimon() API."""
@classmethod
def setUpClass(cls):
cls.tempdir = tempfile.mkdtemp()
cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
cls.catalog_options = {'warehouse': cls.warehouse}
catalog = CatalogFactory.create(cls.catalog_options)
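        # The boolean flag tolerates the database already existing.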
catalog.create_database('default', True)
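        # Start one small local Ray cluster for the whole suite; two CPUs
        # are plenty for these single-node read/write tests.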
if not ray.is_initialized():
ray.init(ignore_reinit_error=True, num_cpus=2)

    @classmethod
def tearDownClass(cls):
try:
if ray.is_initialized():
ray.shutdown()
except Exception:
pass
try:
shutil.rmtree(cls.tempdir)
except OSError:
pass

    def _create_and_populate_table(self, table_name, pa_schema, data_dict,
primary_keys=None, partition_keys=None, options=None):
"""Helper to create a table and write a single batch of data."""
identifier = 'default.{}'.format(table_name)
schema = Schema.from_pyarrow_schema(
pa_schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
options=options,
)
catalog = CatalogFactory.create(self.catalog_options)
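        # The trailing False is ignore_if_exists: creating the same table
        # twice should fail fast rather than silently reuse it.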
catalog.create_table(identifier, schema, False)
table = catalog.get_table(identifier)
test_data = pa.Table.from_pydict(data_dict, schema=pa_schema)
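        # Paimon batch-write protocol: write Arrow data, collect the commit
        # messages from the writer, then commit them atomically.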
write_builder = table.new_batch_write_builder()
writer = write_builder.new_write()
writer.write_arrow(test_data)
commit_messages = writer.prepare_commit()
commit = write_builder.new_commit()
commit.commit(commit_messages)
writer.close()
return identifier

    def test_read_paimon_basic(self):
"""read_paimon() reads back the data we wrote."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([
('id', pa.int32()),
('name', pa.string()),
('value', pa.int64()),
])
identifier = self._create_and_populate_table(
'test_read_basic', pa_schema,
{'id': [1, 2, 3], 'name': ['a', 'b', 'c'], 'value': [10, 20, 30]},
)
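        # Ask Ray Data for a single output block for this read.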
ds = read_paimon(identifier, self.catalog_options, override_num_blocks=1)
self.assertEqual(ds.count(), 3)
df = ds.to_pandas().sort_values('id').reset_index(drop=True)
self.assertEqual(list(df['id']), [1, 2, 3])
self.assertEqual(list(df['name']), ['a', 'b', 'c'])

    def test_read_paimon_with_projection(self):
"""read_paimon() respects column projection."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([
('id', pa.int32()),
('name', pa.string()),
('value', pa.int64()),
])
identifier = self._create_and_populate_table(
'test_read_proj', pa_schema,
{'id': [1, 2], 'name': ['a', 'b'], 'value': [10, 20]},
)
ds = read_paimon(identifier, self.catalog_options, projection=['id', 'name'])
df = ds.to_pandas()
self.assertEqual(set(df.columns), {'id', 'name'})
self.assertEqual(len(df), 2)

    def test_read_paimon_with_filter(self):
"""read_paimon() pushes down a predicate filter."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([
('id', pa.int32()),
('category', pa.string()),
])
identifier = self._create_and_populate_table(
'test_read_filter', pa_schema,
{'id': [1, 2, 3], 'category': ['A', 'B', 'A']},
)
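        # Build a Paimon predicate against the table schema; read_paimon
        # pushes it down into the scan instead of filtering in Ray.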
catalog = CatalogFactory.create(self.catalog_options)
table = catalog.get_table(identifier)
pb = table.new_read_builder().new_predicate_builder()
predicate = pb.equal('category', 'A')
ds = read_paimon(identifier, self.catalog_options, filter=predicate)
self.assertEqual(ds.count(), 2)
df = ds.to_pandas()
self.assertEqual(set(df['category'].tolist()), {'A'})

    def test_read_paimon_with_limit(self):
        """``read_paimon(limit=N)`` propagates the limit into the scan plan.

        Writes 10 rows across two partitions (5 + 5) so the scan produces two
        raw-convertible splits. ``limit=3`` causes ``FileScanner`` to drop the
        second split once the first already covers the limit, so the Ray
        Dataset contains strictly fewer than the full 10 rows.

        We assert ``< 10`` (not ``== N``) because Paimon's scan-time limit is
        a per-split cap (it applies at whole-split granularity), not a
        row-exact hard limit. Row-exact short-circuiting in the reader is a
        separate follow-up.
        """
from pypaimon.ray import read_paimon
pa_schema = pa.schema([
('id', pa.int32()),
('part', pa.string()),
('value', pa.string()),
])
identifier = self._create_and_populate_table(
'test_read_limit', pa_schema,
{
'id': list(range(10)),
'part': ['a'] * 5 + ['b'] * 5,
'value': [str(i) for i in range(10)],
},
partition_keys=['part'],
)
# Sanity baseline: the full unbounded scan returns all 10 rows.
ds_full = read_paimon(identifier, self.catalog_options)
self.assertEqual(ds_full.count(), 10)
# With limit=3, the scan plan drops the second partition's split
# once the first split's row count already covers the limit.
ds = read_paimon(identifier, self.catalog_options, limit=3)
limited_count = ds.count()
self.assertGreater(limited_count, 0)
self.assertLess(limited_count, 10)

    def test_read_paimon_empty_table(self):
"""read_paimon() on a table with no data returns an empty dataset."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([('id', pa.int32())])
identifier = 'default.test_read_empty'
catalog = CatalogFactory.create(self.catalog_options)
schema = Schema.from_pyarrow_schema(pa_schema)
catalog.create_table(identifier, schema, False)
ds = read_paimon(identifier, self.catalog_options)
self.assertEqual(ds.count(), 0)

    def test_read_paimon_with_snapshot_id(self):
"""read_paimon(snapshot_id=N) time-travels to that snapshot."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([('id', pa.int32()), ('name', pa.string())])
identifier = 'default.test_read_snap_id'
catalog = CatalogFactory.create(self.catalog_options)
schema = Schema.from_pyarrow_schema(pa_schema)
catalog.create_table(identifier, schema, False)
table = catalog.get_table(identifier)
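        # Two separate commits produce snapshots 1 and 2, one row each.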
for batch in [{'id': [1], 'name': ['a']}, {'id': [2], 'name': ['b']}]:
wb = table.new_batch_write_builder()
writer = wb.new_write()
writer.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema))
wb.new_commit().commit(writer.prepare_commit())
writer.close()
ds_latest = read_paimon(identifier, self.catalog_options)
self.assertEqual(ds_latest.count(), 2)
ds_snap1 = read_paimon(identifier, self.catalog_options, snapshot_id=1)
self.assertEqual(ds_snap1.count(), 1)
self.assertEqual(ds_snap1.to_pandas()['id'].tolist(), [1])

    def test_read_paimon_with_tag_name(self):
"""read_paimon(tag_name=...) time-travels to a tagged snapshot."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([('id', pa.int32()), ('name', pa.string())])
identifier = 'default.test_read_tag_name'
catalog = CatalogFactory.create(self.catalog_options)
schema = Schema.from_pyarrow_schema(pa_schema)
catalog.create_table(identifier, schema, False)
table = catalog.get_table(identifier)
wb = table.new_batch_write_builder()
writer = wb.new_write()
writer.write_arrow(pa.Table.from_pydict({'id': [1], 'name': ['a']}, schema=pa_schema))
wb.new_commit().commit(writer.prepare_commit())
writer.close()
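        # Tag the first snapshot; the write below lands in snapshot 2 and
        # must stay invisible when reading through the tag.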
table.create_tag('v1')
wb = table.new_batch_write_builder()
writer = wb.new_write()
writer.write_arrow(pa.Table.from_pydict({'id': [2], 'name': ['b']}, schema=pa_schema))
wb.new_commit().commit(writer.prepare_commit())
writer.close()
ds_tag = read_paimon(identifier, self.catalog_options, tag_name='v1')
self.assertEqual(ds_tag.count(), 1)
self.assertEqual(ds_tag.to_pandas()['id'].tolist(), [1])

    def test_read_paimon_rejects_snapshot_id_and_tag_name_together(self):
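        """snapshot_id and tag_name are mutually exclusive read options."""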
from pypaimon.ray import read_paimon
with self.assertRaises(ValueError):
read_paimon(
'default.dummy', self.catalog_options,
snapshot_id=1, tag_name='v1',
)

    def test_write_paimon_basic(self):
"""write_paimon() writes data that read_paimon() can round-trip."""
from pypaimon.ray import read_paimon, write_paimon
pa_schema = pa.schema([
('id', pa.int32()),
('name', pa.string()),
])
identifier = 'default.test_write_basic'
catalog = CatalogFactory.create(self.catalog_options)
schema = Schema.from_pyarrow_schema(pa_schema)
catalog.create_table(identifier, schema, False)
source = pa.Table.from_pydict(
{'id': [1, 2, 3], 'name': ['x', 'y', 'z']}, schema=pa_schema,
)
ds = ray.data.from_arrow(source)
write_paimon(ds, identifier, self.catalog_options)
result = read_paimon(identifier, self.catalog_options)
self.assertEqual(result.count(), 3)
df = result.to_pandas().sort_values('id').reset_index(drop=True)
self.assertEqual(list(df['name']), ['x', 'y', 'z'])

    def test_write_paimon_overwrite(self):
"""write_paimon(overwrite=True) replaces existing data."""
from pypaimon.ray import read_paimon, write_paimon
pa_schema = pa.schema([
('id', pa.int32()),
('val', pa.int64()),
])
identifier = 'default.test_write_overwrite'
catalog = CatalogFactory.create(self.catalog_options)
schema = Schema.from_pyarrow_schema(pa_schema)
catalog.create_table(identifier, schema, False)
ds1 = ray.data.from_arrow(
pa.Table.from_pydict({'id': [1, 2], 'val': [10, 20]}, schema=pa_schema)
)
write_paimon(ds1, identifier, self.catalog_options)
ds2 = ray.data.from_arrow(
pa.Table.from_pydict({'id': [3], 'val': [30]}, schema=pa_schema)
)
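        # overwrite=True replaces the existing rows instead of appending.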
write_paimon(ds2, identifier, self.catalog_options, overwrite=True)
result = read_paimon(identifier, self.catalog_options)
self.assertEqual(result.count(), 1)
df = result.to_pandas()
self.assertEqual(list(df['id']), [3])

    def test_read_paimon_primary_key(self):
"""read_paimon() merges PK rows correctly after an upsert."""
from pypaimon.ray import read_paimon
pa_schema = pa.schema([
pa.field('id', pa.int32(), nullable=False),
('name', pa.string()),
])
identifier = self._create_and_populate_table(
'test_read_pk', pa_schema,
{'id': [1, 2, 3], 'name': ['a', 'b', 'c']},
primary_keys=['id'],
options={'bucket': '2'},
)
catalog = CatalogFactory.create(self.catalog_options)
table = catalog.get_table(identifier)
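        # Upsert: id=1 overwrites the existing row, id=4 is a new key, so
        # the merged view holds exactly four rows.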
update = pa.Table.from_pydict({'id': [1, 4], 'name': ['a2', 'd']}, schema=pa_schema)
wb = table.new_batch_write_builder()
w = wb.new_write()
w.write_arrow(update)
msgs = w.prepare_commit()
wb.new_commit().commit(msgs)
w.close()
ds = read_paimon(identifier, self.catalog_options)
self.assertEqual(ds.count(), 4)
df = ds.to_pandas().sort_values('id').reset_index(drop=True)
self.assertEqual(list(df['name']), ['a2', 'b', 'c', 'd'])

    def test_read_paimon_invalid_override_num_blocks(self):
"""override_num_blocks below 1 is rejected with a clear error."""
from pypaimon.ray import read_paimon
with self.assertRaises(ValueError):
read_paimon('default.does_not_matter', self.catalog_options,
override_num_blocks=0)

    def test_read_paimon_pk_single_snapshot(self):
        """read_paimon on a PK table with a single snapshot (raw-convertible
        splits) must not raise ArrowInvalid on a schema nullability mismatch.

        The Paimon table schema marks PK columns as NOT NULL, but the
        Parquet reader may produce nullable fields. The RayDatasource
        read task must cast the batch to align the schema rather than
        rejecting it via strict from_batches equality.
        """
from pypaimon.ray import read_paimon
pa_schema = pa.schema([
pa.field('id', pa.int32(), nullable=False),
('name', pa.string()),
])
identifier = self._create_and_populate_table(
'test_read_pk_single_snap', pa_schema,
{'id': [1, 2, 3], 'name': ['a', 'b', 'c']},
primary_keys=['id'], options={'bucket': '2'},
)
ds = read_paimon(identifier, self.catalog_options)
self.assertEqual(ds.count(), 3)
df = ds.to_pandas().sort_values('id').reset_index(drop=True)
self.assertEqual(list(df['id']), [1, 2, 3])


if __name__ == '__main__':
unittest.main()