| """ |
| Licensed to the Apache Software Foundation (ASF) under one |
| or more contributor license agreements. See the NOTICE file |
| distributed with this work for additional information |
| regarding copyright ownership. The ASF licenses this file |
| to you under the Apache License, Version 2.0 (the |
| "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| """ |
| |
| import pyarrow as pa |
| |
| from pypaimon import Schema |
| from pypaimon.catalog.catalog_exception import DatabaseAlreadyExistException, TableAlreadyExistException, \ |
| DatabaseNotExistException, TableNotExistException |
| from pypaimon.schema.data_types import AtomicType, DataField |
| from pypaimon.schema.schema_change import SchemaChange |
| from pypaimon.tests.rest.rest_base_test import RESTBaseTest |
| from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor, DynamicBucketRowKeyExtractor, \ |
| UnawareBucketRowKeyExtractor |
| |
| |
| class RESTSimpleTest(RESTBaseTest): |
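| """Tests for the REST catalog: sharded batch reads via with_shard, bucket modes, database/table DDL and schema evolution.""" |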
| def setUp(self): |
| super().setUp() |
| self.pa_schema = pa.schema([ |
| ('user_id', pa.int64()), |
| ('item_id', pa.int64()), |
| ('behavior', pa.string()), |
| ('dt', pa.string()), |
| ]) |
| self.data = { |
| 'user_id': [2, 4, 6, 8, 10], |
| 'item_id': [1001, 1002, 1003, 1004, 1005], |
| 'behavior': ['a', 'b', 'c', 'd', 'e'], |
| 'dt': ['2000-10-10', '2025-08-10', '2025-08-11', '2025-08-12', '2025-08-13'] |
| } |
| self.expected = pa.Table.from_pydict(self.data, schema=self.pa_schema) |
| |
| def test_with_shard_ao_unaware_bucket(self): |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket') |
| write_builder = table.new_batch_write_builder() |
| # first write |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data1 = { |
| 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], |
| 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], |
| 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| # second write |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data2 = { |
| 'user_id': [5, 6, 7, 8, 18], |
| 'item_id': [1005, 1006, 1007, 1008, 1018], |
| 'behavior': ['e', 'f', 'g', 'h', 'z'], |
| 'dt': ['p2', 'p1', 'p2', 'p2', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
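| # with_shard(idx, n) plans only the idx-th of n shards; here we read shard 2 of 3 and check its exact contents |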
| splits = read_builder.new_scan().with_shard(2, 3).plan().splits() |
| actual = table_read.to_arrow(splits).sort_by('user_id') |
| expected = pa.Table.from_pydict({ |
| 'user_id': [5, 7, 8, 9, 11, 13], |
| 'item_id': [1005, 1007, 1008, 1009, 1011, 1013], |
| 'behavior': ['e', 'g', 'h', 'h', 'j', 'l'], |
| 'dt': ['p2', 'p2', 'p2', 'p2', 'p2', 'p2'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| # Get the three actual tables |
| splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() |
| actual1 = table_read.to_arrow(splits1).sort_by('user_id') |
| splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() |
| actual2 = table_read.to_arrow(splits2).sort_by('user_id') |
| splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() |
| actual3 = table_read.to_arrow(splits3).sort_by('user_id') |
| |
| # Concatenate the three tables |
| actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id') |
| expected = self._read_test_table(read_builder).sort_by('user_id') |
| self.assertEqual(actual, expected) |
| |
| def test_with_shard_ao_unaware_bucket_manual(self): |
| """Test shard_ao_unaware_bucket with setting bucket -1 manually""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], |
| options={'bucket': '-1'}) |
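| # 'bucket' = '-1' selects unaware-bucket (append-only) mode, so the writer should use UnawareBucketRowKeyExtractor (asserted below) |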
| self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket_manual', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket_manual') |
| write_builder = table.new_batch_write_builder() |
| |
| # Write data with single partition |
| table_write = write_builder.new_write() |
| self.assertIsInstance(table_write.row_key_extractor, UnawareBucketRowKeyExtractor) |
| |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4, 5, 6], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], |
| 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], |
| 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test first shard (0, 2) - should get first 3 rows |
| plan = read_builder.new_scan().with_shard(0, 2).plan() |
| actual = table_read.to_arrow(plan.splits()).sort_by('user_id') |
| expected = pa.Table.from_pydict({ |
| 'user_id': [1, 2, 3], |
| 'item_id': [1001, 1002, 1003], |
| 'behavior': ['a', 'b', 'c'], |
| 'dt': ['p1', 'p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| # Test second shard (1, 2) - should get last 3 rows |
| plan = read_builder.new_scan().with_shard(1, 2).plan() |
| actual = table_read.to_arrow(plan.splits()).sort_by('user_id') |
| expected = pa.Table.from_pydict({ |
| 'user_id': [4, 5, 6], |
| 'item_id': [1004, 1005, 1006], |
| 'behavior': ['d', 'e', 'f'], |
| 'dt': ['p1', 'p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| def test_with_shard_ao_fixed_bucket(self): |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], |
| options={'bucket': '5', 'bucket-key': 'item_id'}) |
| self.rest_catalog.create_table('default.test_with_shard_ao_fixed_bucket', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_ao_fixed_bucket') |
| write_builder = table.new_batch_write_builder() |
| # first write |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data1 = { |
| 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], |
| 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], |
| 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| # second write |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data2 = { |
| 'user_id': [5, 6, 7, 8], |
| 'item_id': [1005, 1006, 1007, 1008], |
| 'behavior': ['e', 'f', 'g', 'h'], |
| 'dt': ['p2', 'p1', 'p2', 'p2'], |
| } |
| pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| splits = read_builder.new_scan().with_shard(0, 3).plan().splits() |
| actual = table_read.to_arrow(splits).sort_by('user_id') |
| expected = pa.Table.from_pydict({ |
| 'user_id': [1, 2, 3, 5, 8, 12], |
| 'item_id': [1001, 1002, 1003, 1005, 1008, 1012], |
| 'behavior': ['a', 'b', 'c', 'd', 'g', 'k'], |
| 'dt': ['p1', 'p1', 'p2', 'p2', 'p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| # Get the three actual tables |
| splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() |
| actual1 = table_read.to_arrow(splits1).sort_by('user_id') |
| splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() |
| actual2 = table_read.to_arrow(splits2).sort_by('user_id') |
| splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() |
| actual3 = table_read.to_arrow(splits3).sort_by('user_id') |
| |
| # Concatenate the three tables |
| actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id') |
| expected = self._read_test_table(read_builder).sort_by('user_id') |
| self.assertEqual(actual, expected) |
| |
| def test_with_shard_single_partition(self): |
| """Test sharding with single partition - tests _filter_by_shard with simple data""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_shard_single_partition', schema, False) |
| table = self.rest_catalog.get_table('default.test_shard_single_partition') |
| write_builder = table.new_batch_write_builder() |
| |
| # Write data with single partition |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4, 5, 6], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], |
| 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], |
| 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test first shard (0, 2) - should get first 3 rows |
| plan = read_builder.new_scan().with_shard(0, 2).plan() |
| actual = table_read.to_arrow(plan.splits()).sort_by('user_id') |
| expected = pa.Table.from_pydict({ |
| 'user_id': [1, 2, 3], |
| 'item_id': [1001, 1002, 1003], |
| 'behavior': ['a', 'b', 'c'], |
| 'dt': ['p1', 'p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| # Test second shard (1, 2) - should get last 3 rows |
| plan = read_builder.new_scan().with_shard(1, 2).plan() |
| actual = table_read.to_arrow(plan.splits()).sort_by('user_id') |
| expected = pa.Table.from_pydict({ |
| 'user_id': [4, 5, 6], |
| 'item_id': [1004, 1005, 1006], |
| 'behavior': ['d', 'e', 'f'], |
| 'dt': ['p1', 'p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| def test_with_shard_uneven_distribution(self): |
| """Test sharding with uneven row distribution across shards""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_shard_uneven', schema, False) |
| table = self.rest_catalog.get_table('default.test_shard_uneven') |
| write_builder = table.new_batch_write_builder() |
| |
| # Write data with 7 rows (not evenly divisible by 3) |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4, 5, 6, 7], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007], |
| 'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g'], |
| 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test sharding 7 rows into 3 parts of 3, 2 and 2 rows; the extra row goes to the first shard |
| plan1 = read_builder.new_scan().with_shard(0, 3).plan() |
| actual1 = table_read.to_arrow(plan1.splits()).sort_by('user_id') |
| expected1 = pa.Table.from_pydict({ |
| 'user_id': [1, 2, 3], |
| 'item_id': [1001, 1002, 1003], |
| 'behavior': ['a', 'b', 'c'], |
| 'dt': ['p1', 'p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual1, expected1) |
| |
| plan2 = read_builder.new_scan().with_shard(1, 3).plan() |
| actual2 = table_read.to_arrow(plan2.splits()).sort_by('user_id') |
| expected2 = pa.Table.from_pydict({ |
| 'user_id': [4, 5], |
| 'item_id': [1004, 1005], |
| 'behavior': ['d', 'e'], |
| 'dt': ['p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual2, expected2) |
| |
| plan3 = read_builder.new_scan().with_shard(2, 3).plan() |
| actual3 = table_read.to_arrow(plan3.splits()).sort_by('user_id') |
| expected3 = pa.Table.from_pydict({ |
| 'user_id': [6, 7], |
| 'item_id': [1006, 1007], |
| 'behavior': ['f', 'g'], |
| 'dt': ['p1', 'p1'], |
| }, schema=self.pa_schema) |
| self.assertEqual(actual3, expected3) |
| |
| def test_with_shard_single_shard(self): |
| """Test sharding with only one shard - should return all data""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_shard_single', schema, False) |
| table = self.rest_catalog.get_table('default.test_shard_single') |
| write_builder = table.new_batch_write_builder() |
| |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4], |
| 'item_id': [1001, 1002, 1003, 1004], |
| 'behavior': ['a', 'b', 'c', 'd'], |
| 'dt': ['p1', 'p1', 'p2', 'p2'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test single shard (0, 1) - should get all data |
| plan = read_builder.new_scan().with_shard(0, 1).plan() |
| actual = table_read.to_arrow(plan.splits()).sort_by('user_id') |
| expected = pa.Table.from_pydict(data, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| def test_with_shard_many_small_shards(self): |
| """Test sharding with many small shards""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_shard_many_small', schema, False) |
| table = self.rest_catalog.get_table('default.test_shard_many_small') |
| write_builder = table.new_batch_write_builder() |
| |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4, 5, 6], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], |
| 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], |
| 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test with 6 shards (one row per shard) |
| for i in range(6): |
| plan = read_builder.new_scan().with_shard(i, 6).plan() |
| actual = table_read.to_arrow(plan.splits()) |
| self.assertEqual(len(actual), 1) |
| self.assertEqual(actual['user_id'][0].as_py(), i + 1) |
| |
| def test_with_shard_boundary_conditions(self): |
| """Test sharding boundary conditions with edge cases""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_shard_boundary', schema, False) |
| table = self.rest_catalog.get_table('default.test_shard_boundary') |
| write_builder = table.new_batch_write_builder() |
| |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4, 5], |
| 'item_id': [1001, 1002, 1003, 1004, 1005], |
| 'behavior': ['a', 'b', 'c', 'd', 'e'], |
| 'dt': ['p1', 'p1', 'p1', 'p1', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test first shard (0, 4) - should get 2 rows: 5 rows over 4 shards leaves a remainder of 1, which goes to the first shard (5 // 4 + 1 = 2) |
| plan = read_builder.new_scan().with_shard(0, 4).plan() |
| actual = table_read.to_arrow(plan.splits()) |
| self.assertEqual(len(actual), 2) |
| |
| # Test middle shard (1, 4) - should get 1 row |
| plan = read_builder.new_scan().with_shard(1, 4).plan() |
| actual = table_read.to_arrow(plan.splits()) |
| self.assertEqual(len(actual), 1) |
| |
| # Test last shard (3, 4) - should get 1 row (the remainder was already assigned to the first shard) |
| plan = read_builder.new_scan().with_shard(3, 4).plan() |
| actual = table_read.to_arrow(plan.splits()) |
| self.assertEqual(len(actual), 1) |
| |
| def test_with_shard_large_dataset(self): |
| """Test with_shard method using 50000 rows of data to verify performance and correctness""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], |
| options={'bucket': '5', 'bucket-key': 'item_id'}) |
| self.rest_catalog.create_table('default.test_with_shard_large_dataset', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_large_dataset') |
| write_builder = table.new_batch_write_builder() |
| |
| # Generate 50000 rows of test data |
| num_rows = 50000 |
| batch_size = 5000 # Write in batches to avoid memory issues |
| |
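| # Each batch gets its own write/commit, so the table ends up with multiple snapshots and data files, exercising sharding across many splits |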
| for batch_start in range(0, num_rows, batch_size): |
| batch_end = min(batch_start + batch_size, num_rows) |
| batch_data = { |
| 'user_id': list(range(batch_start + 1, batch_end + 1)), |
| 'item_id': [2000 + i for i in range(batch_start, batch_end)], |
| 'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_start, batch_end)], |
| 'dt': [f'p{(i % 5) + 1}' for i in range(batch_start, batch_end)], |
| } |
| |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| pa_table = pa.Table.from_pydict(batch_data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Test with 6 shards |
| num_shards = 6 |
| shard_results = [] |
| total_rows_from_shards = 0 |
| |
| for shard_idx in range(num_shards): |
| splits = read_builder.new_scan().with_shard(shard_idx, num_shards).plan().splits() |
| shard_result = table_read.to_arrow(splits) |
| shard_results.append(shard_result) |
| shard_rows = len(shard_result) if shard_result else 0 |
| total_rows_from_shards += shard_rows |
| print(f"Shard {shard_idx}/{num_shards}: {shard_rows} rows") |
| |
| # Verify that all shards together contain all the data |
| concatenated_result = pa.concat_tables(shard_results).sort_by('user_id') |
| |
| # Read all data without sharding for comparison |
| all_splits = read_builder.new_scan().plan().splits() |
| all_data = table_read.to_arrow(all_splits).sort_by('user_id') |
| |
| # Verify total row count |
| self.assertEqual(len(concatenated_result), len(all_data)) |
| self.assertEqual(len(all_data), num_rows) |
| self.assertEqual(total_rows_from_shards, num_rows) |
| |
| # Verify data integrity - check first and last few rows |
| self.assertEqual(concatenated_result['user_id'][0].as_py(), 1) |
| self.assertEqual(concatenated_result['user_id'][-1].as_py(), num_rows) |
| self.assertEqual(concatenated_result['item_id'][0].as_py(), 2000) |
| self.assertEqual(concatenated_result['item_id'][-1].as_py(), 2000 + num_rows - 1) |
| |
| # Verify that concatenated result equals all data |
| self.assertEqual(concatenated_result, all_data) |
| |
| # Test with a different shard configuration: 10 shards |
| shard_10_results = [] |
| for shard_idx in range(10): |
| splits = read_builder.new_scan().with_shard(shard_idx, 10).plan().splits() |
| shard_result = table_read.to_arrow(splits) |
| if shard_result: |
| shard_10_results.append(shard_result) |
| |
| if shard_10_results: |
| concatenated_10_shards = pa.concat_tables(shard_10_results).sort_by('user_id') |
| self.assertEqual(len(concatenated_10_shards), num_rows) |
| self.assertEqual(concatenated_10_shards, all_data) |
| |
| # Test with single shard (should return all data) |
| single_shard_splits = read_builder.new_scan().with_shard(0, 1).plan().splits() |
| single_shard_result = table_read.to_arrow(single_shard_splits).sort_by('user_id') |
| self.assertEqual(len(single_shard_result), num_rows) |
| self.assertEqual(single_shard_result, all_data) |
| |
| print(f"Successfully tested with_shard method using {num_rows} rows of data") |
| |
| def test_with_shard_large_dataset_one_commit(self): |
| """Test with_shard method using 50000 rows of data to verify performance and correctness""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema) |
| self.rest_catalog.create_table('default.test_with_shard_large_dataset_one_commit', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_large_dataset_one_commit') |
| write_builder = table.new_batch_write_builder() |
| |
| # Generate 50000 rows of test data |
| num_rows = 50000 |
| batch_data = { |
| 'user_id': list(range(0, num_rows)), |
| 'item_id': [2000 + i for i in range(0, num_rows)], |
| 'behavior': [chr(ord('a') + (i % 26)) for i in range(0, num_rows)], |
| 'dt': [f'p{(i % 5) + 1}' for i in range(0, num_rows)], |
| } |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| pa_table = pa.Table.from_pydict(batch_data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| num_shards = 5 |
| shard_results = [] |
| total_rows_from_shards = 0 |
| for shard_idx in range(num_shards): |
| splits = read_builder.new_scan().with_shard(shard_idx, num_shards).plan().splits() |
| shard_result = table_read.to_arrow(splits) |
| shard_results.append(shard_result) |
| shard_rows = len(shard_result) if shard_result else 0 |
| total_rows_from_shards += shard_rows |
| print(f"Shard {shard_idx}/{num_shards}: {shard_rows} rows") |
| |
| # Verify that all shards together contain all the data |
| concatenated_result = pa.concat_tables(shard_results).sort_by('user_id') |
| |
| # Read all data without sharding for comparison |
| all_splits = read_builder.new_scan().plan().splits() |
| all_data = table_read.to_arrow(all_splits).sort_by('user_id') |
| |
| # Verify total row count |
| self.assertEqual(len(concatenated_result), len(all_data)) |
| self.assertEqual(len(all_data), num_rows) |
| self.assertEqual(total_rows_from_shards, num_rows) |
| |
| # Verify data integrity - check first and last few rows |
| self.assertEqual(concatenated_result['user_id'][0].as_py(), 0) |
| self.assertEqual(concatenated_result['user_id'][-1].as_py(), num_rows - 1) |
| self.assertEqual(concatenated_result['item_id'][0].as_py(), 2000) |
| self.assertEqual(concatenated_result['item_id'][-1].as_py(), 2000 + num_rows - 1) |
| |
| # Verify that concatenated result equals all data |
| self.assertEqual(concatenated_result, all_data) |
| |
| def test_with_shard_parameter_validation(self): |
| """Test edge cases for parameter validation""" |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.test_shard_validation_edge', schema, False) |
| table = self.rest_catalog.get_table('default.test_shard_validation_edge') |
| |
| read_builder = table.new_read_builder() |
| # Invalid case: with_shard(1, 1) - idx_of_this_subtask must be strictly less than number_of_para_subtasks |
| with self.assertRaises(Exception) as context: |
| read_builder.new_scan().with_shard(1, 1).plan() |
| self.assertEqual(str(context.exception), "idx_of_this_subtask must be less than number_of_para_subtasks") |
| |
| def test_with_shard_pk_dynamic_bucket(self): |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], primary_keys=['user_id', 'dt']) |
| self.rest_catalog.create_table('default.test_with_shard_pk_dynamic_bucket', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_pk_dynamic_bucket') |
| |
| write_builder = table.new_batch_write_builder() |
| table_write = write_builder.new_write() |
| self.assertIsInstance(table_write.row_key_extractor, DynamicBucketRowKeyExtractor) |
| |
| pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema) |
| |
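| # Writing rows without explicit bucket assignment is not supported in dynamic bucket mode, so write_arrow is expected to raise |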
| with self.assertRaises(ValueError) as context: |
| table_write.write_arrow(pa_table) |
| |
| self.assertEqual(str(context.exception), "Can't extract bucket from row in dynamic bucket mode") |
| |
| def test_with_shard_pk_fixed_bucket(self): |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], primary_keys=['user_id', 'dt'], |
| options={'bucket': '5'}) |
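| # 'bucket' = '5' selects fixed-bucket mode: rows are expected to be hashed by bucket key into one of the 5 buckets, via FixedBucketRowKeyExtractor (asserted below) |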
| self.rest_catalog.create_table('default.test_with_shard_pk_fixed_bucket', schema, False) |
| table = self.rest_catalog.get_table('default.test_with_shard_pk_fixed_bucket') |
| |
| write_builder = table.new_batch_write_builder() |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| self.assertIsInstance(table_write.row_key_extractor, FixedBucketRowKeyExtractor) |
| |
| pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
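| # Collect the splits of all three shards; together they should cover the table exactly once |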
| splits = [] |
| read_builder = table.new_read_builder() |
| splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits()) |
| splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits()) |
| splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits()) |
| |
| table_read = read_builder.new_read() |
| actual = table_read.to_arrow(splits) |
| data_expected = { |
| 'user_id': [4, 6, 2, 10, 8], |
| 'item_id': [1002, 1003, 1001, 1005, 1004], |
| 'behavior': ['b', 'c', 'a', 'e', 'd'], |
| 'dt': ['2025-08-10', '2025-08-11', '2000-10-10', '2025-08-13', '2025-08-12'] |
| } |
| expected = pa.Table.from_pydict(data_expected, schema=self.pa_schema) |
| self.assertEqual(actual, expected) |
| |
| def test_with_shard_uniform_division(self): |
| schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) |
| self.rest_catalog.create_table('default.with_shard_uniform_division', schema, False) |
| table = self.rest_catalog.get_table('default.with_shard_uniform_division') |
| write_builder = table.new_batch_write_builder() |
| table_write = write_builder.new_write() |
| table_commit = write_builder.new_commit() |
| data = { |
| 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], |
| 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], |
| 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], |
| 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], |
| } |
| pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) |
| table_write.write_arrow(pa_table) |
| table_commit.commit(table_write.prepare_commit()) |
| table_write.close() |
| table_commit.close() |
| |
| read_builder = table.new_read_builder() |
| table_read = read_builder.new_read() |
| |
| # Get the three actual tables |
| splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() |
| actual1 = table_read.to_arrow(splits1).sort_by('user_id') |
| splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() |
| actual2 = table_read.to_arrow(splits2).sort_by('user_id') |
| splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() |
| actual3 = table_read.to_arrow(splits3).sort_by('user_id') |
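| # 14 rows over 3 shards should split as evenly as possible: 5, 5 and 4 rows |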
| self.assertEqual(5, len(actual1)) |
| self.assertEqual(5, len(actual2)) |
| self.assertEqual(4, len(actual3)) |
| # Concatenate the three tables |
| actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id') |
| expected = self._read_test_table(read_builder).sort_by('user_id') |
| self.assertEqual(expected, actual) |
| |
| def test_create_drop_database_table(self): |
| # test create database |
| self.rest_catalog.create_database("db1", False) |
| |
| with self.assertRaises(DatabaseAlreadyExistException) as context: |
| self.rest_catalog.create_database("db1", False) |
| |
| self.assertEqual("db1", context.exception.database) |
| |
| try: |
| self.rest_catalog.create_database("db1", True) |
| except DatabaseAlreadyExistException: |
| self.fail("create_database with ignore_if_exists=True should not raise DatabaseAlreadyExistException") |
| |
| # test create table |
| self.rest_catalog.create_table("db1.tbl1", |
| Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']), |
| False) |
| with self.assertRaises(TableAlreadyExistException) as context: |
| self.rest_catalog.create_table("db1.tbl1", |
| Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']), |
| False) |
| self.assertEqual("db1.tbl1", context.exception.identifier.get_full_name()) |
| |
| try: |
| self.rest_catalog.create_table("db1.tbl1", |
| Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']), |
| True) |
| except TableAlreadyExistException: |
| self.fail("create_table with ignore_if_exists=True should not raise TableAlreadyExistException") |
| |
| # test drop table |
| self.rest_catalog.drop_table("db1.tbl1", False) |
| with self.assertRaises(TableNotExistException) as context: |
| self.rest_catalog.drop_table("db1.tbl1", False) |
| self.assertEqual("db1.tbl1", context.exception.identifier.get_full_name()) |
| |
| try: |
| self.rest_catalog.drop_table("db1.tbl1", True) |
| except TableNotExistException: |
| self.fail("drop_table with ignore_if_not_exists=True should not raise TableNotExistException") |
| |
| # test drop database |
| self.rest_catalog.drop_database("db1", False) |
| with self.assertRaises(DatabaseNotExistException) as context: |
| self.rest_catalog.drop_database("db1", False) |
| self.assertEqual("db1", context.exception.database) |
| |
| try: |
| self.rest_catalog.drop_database("db1", True) |
| except DatabaseNotExistException: |
| self.fail("drop_database with ignore_if_not_exists=True should not raise DatabaseNotExistException") |
| |
| def test_alter_table(self): |
| catalog = self.rest_catalog |
| catalog.create_database("test_db_alter", True) |
| |
| identifier = "test_db_alter.test_table" |
| schema = Schema( |
| fields=[ |
| DataField.from_dict({"id": 0, "name": "col1", "type": "STRING", "description": "field1"}), |
| DataField.from_dict({"id": 1, "name": "col2", "type": "STRING", "description": "field2"}) |
| ], |
| partition_keys=[], |
| primary_keys=[], |
| options={}, |
| comment="comment" |
| ) |
| catalog.create_table(identifier, schema, False) |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.add_column("col3", AtomicType("DATE"))], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertEqual(len(table.fields), 3) |
| self.assertEqual(table.fields[2].name, "col3") |
| self.assertEqual(table.fields[2].type.type, "DATE") |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.update_comment("new comment")], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertEqual(table.table_schema.comment, "new comment") |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.rename_column("col1", "new_col1")], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertEqual(table.fields[0].name, "new_col1") |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.update_column_type("col2", AtomicType("BIGINT"))], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertEqual(table.fields[1].type.type, "BIGINT") |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.update_column_comment("col2", "col2 field")], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertEqual(table.fields[1].description, "col2 field") |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.set_option("write-buffer-size", "256 MB")], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertEqual(table.table_schema.options.get("write-buffer-size"), "256 MB") |
| |
| catalog.alter_table( |
| identifier, |
| [SchemaChange.remove_option("write-buffer-size")], |
| False |
| ) |
| table = catalog.get_table(identifier) |
| self.assertNotIn("write-buffer-size", table.table_schema.options) |
| |
| with self.assertRaises(TableNotExistException): |
| catalog.alter_table( |
| "test_db_alter.non_existing_table", |
| [SchemaChange.add_column("col2", AtomicType("INT"))], |
| False |
| ) |
| |
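| # With the final flag set to True (presumably ignore_if_not_exists), altering a missing table should not raise |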
| catalog.alter_table( |
| "test_db_alter.non_existing_table", |
| [SchemaChange.add_column("col2", AtomicType("INT"))], |
| True |
| ) |