################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
import shutil
import tempfile
import time
import unittest
import pyarrow as pa
from pypaimon import CatalogFactory, Schema
from pypaimon.common.options.core_options import CoreOptions
from pypaimon.snapshot.snapshot_manager import SnapshotManager
class PkReaderTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
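        """Create a temporary warehouse and catalog, the shared pyarrow schema and
        the merged result expected after the two commits of _write_test_table."""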
cls.tempdir = tempfile.mkdtemp()
cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
cls.catalog = CatalogFactory.create({
'warehouse': cls.warehouse
})
cls.catalog.create_database('default', False)
cls.pa_schema = pa.schema([
pa.field('user_id', pa.int32(), nullable=False),
('item_id', pa.int64()),
('behavior', pa.string()),
pa.field('dt', pa.string(), nullable=False)
])
cls.expected = pa.Table.from_pydict({
'user_id': [1, 2, 3, 4, 5, 7, 8],
'item_id': [1001, 1002, 1003, 1004, 1005, 1007, 1008],
'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'],
'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
}, schema=cls.pa_schema)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tempdir, ignore_errors=True)
def test_pk_parquet_reader(self):
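        """Read a parquet PK table written in two commits and verify the merged result,
        then check that the _VALUE_KIND system field is stored as int8 in the data files."""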
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_pk_parquet', schema, False)
table = self.catalog.get_table('default.test_pk_parquet')
self._write_test_table(table)
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
self.assertEqual(actual, self.expected)
# Verify _VALUE_KIND field type is int8 in the written parquet file
table_scan = read_builder.new_scan()
splits = table_scan.plan().splits()
value_kind_field_found = False
for split in splits:
for file in split.files:
file_path = file.file_path
table_path = os.path.join(self.warehouse, 'default.db', 'test_pk_parquet')
full_path = os.path.join(table_path, file_path)
if os.path.exists(full_path) and file_path.endswith('.parquet'):
import pyarrow.parquet as pq
parquet_file = pq.ParquetFile(full_path)
# Use schema_arrow to get Arrow schema instead of ParquetSchema
file_schema = parquet_file.schema_arrow
for i in range(len(file_schema)):
field = file_schema.field(i)
if field.name == '_VALUE_KIND':
value_kind_field_found = True
self.assertEqual(
field.type, pa.int8(),
f"_VALUE_KIND field type should be int8, got {field.type}")
break
if value_kind_field_found:
break
if value_kind_field_found:
break
self.assertTrue(
value_kind_field_found,
"_VALUE_KIND field should exist in the written parquet file")
def test_pk_orc_reader(self):
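        """Read a PK table written with file.format=orc and bucket=1, comparing column
        data only because the ORC field names differ from the pyarrow schema."""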
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={
'bucket': '1',
'file.format': 'orc'
})
self.catalog.create_table('default.test_pk_orc', schema, False)
table = self.catalog.get_table('default.test_pk_orc')
self._write_test_table(table)
read_builder = table.new_read_builder()
actual: pa.Table = self._read_test_table(read_builder).sort_by('user_id')
        # When bucket=1, the actual field names will contain 'not null', so skip comparing
        # field names and compare the columns one by one instead.
for i in range(len(actual.columns)):
col_a = actual.column(i)
col_b = self.expected.column(i)
self.assertEqual(col_a, col_b)
def test_pk_avro_reader(self):
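        """Read a PK table written with file.format=avro and verify the merged result."""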
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={
'bucket': '2',
'file.format': 'avro'
})
self.catalog.create_table('default.test_pk_avro', schema, False)
table = self.catalog.get_table('default.test_pk_avro')
self._write_test_table(table)
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
self.assertEqual(actual, self.expected)
def test_pk_lance_reader(self):
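        """Read a PK table written with file.format=lance, checking that the planned data
        files use the .lance extension and that the merged result matches."""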
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={
'bucket': '2',
'file.format': 'lance'
})
self.catalog.create_table('default.test_pk_lance', schema, False)
table = self.catalog.get_table('default.test_pk_lance')
self._write_test_table(table)
read_builder = table.new_read_builder()
table_scan = read_builder.new_scan()
splits = table_scan.plan().splits()
for split in splits:
for file in split.files:
file_path = file.file_path
table_path = os.path.join(self.warehouse, 'default.db', 'test_pk_lance')
full_path = os.path.join(table_path, file_path)
                if os.path.exists(full_path):
                    # Every data file planned by the scan should be written in the Lance format.
                    self.assertTrue(
                        file_path.endswith('.lance'),
                        f"Expected file path to end with .lance, got {file_path}")
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
self.assertEqual(actual, self.expected)
def test_pk_lance_reader_with_filter(self):
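        """Apply an AND of is_in/between/is_not_null predicates to a lance PK table and
        expect only the rows for user_id 2 and 7."""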
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={
'bucket': '2',
'file.format': 'lance'
})
self.catalog.create_table('default.test_pk_lance_filter', schema, False)
table = self.catalog.get_table('default.test_pk_lance_filter')
self._write_test_table(table)
predicate_builder = table.new_read_builder().new_predicate_builder()
p1 = predicate_builder.is_in('dt', ['p1'])
p2 = predicate_builder.between('user_id', 2, 7)
p3 = predicate_builder.is_not_null('behavior')
g1 = predicate_builder.and_predicates([p1, p2, p3])
read_builder = table.new_read_builder().with_filter(g1)
actual = self._read_test_table(read_builder).sort_by('user_id')
expected = pa.concat_tables([
self.expected.slice(1, 1), # 2/b
self.expected.slice(5, 1) # 7/g
])
self.assertEqual(actual, expected)
def test_pk_multi_write_once_commit(self):
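        """Write two arrow batches through one writer and commit them together; rows are
        currently not merged across the batches, so user_id=2 appears twice (see TODO)."""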
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_pk_multi', schema, False)
table = self.catalog.get_table('default.test_pk_multi')
write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
data1 = {
'user_id': [1, 2, 3, 4],
'item_id': [1001, 1002, 1003, 1004],
'behavior': ['a', 'b', 'c', None],
'dt': ['p1', 'p1', 'p2', 'p1'],
}
pa_table1 = pa.Table.from_pydict(data1, schema=self.pa_schema)
data2 = {
'user_id': [5, 2, 7, 8],
'item_id': [1005, 1002, 1007, 1008],
'behavior': ['e', 'b-new', 'g', 'h'],
'dt': ['p2', 'p1', 'p1', 'p2']
}
pa_table2 = pa.Table.from_pydict(data2, schema=self.pa_schema)
table_write.write_arrow(pa_table1)
table_write.write_arrow(pa_table2)
table_commit.commit(table_write.prepare_commit())
table_write.close()
table_commit.close()
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
        # TODO: support the PK merge feature when multiple writes happen within a single commit
expected = pa.Table.from_pydict({
'user_id': [1, 2, 2, 3, 4, 5, 7, 8],
'item_id': [1001, 1002, 1002, 1003, 1004, 1005, 1007, 1008],
'behavior': ['a', 'b', 'b-new', 'c', None, 'e', 'g', 'h'],
'dt': ['p1', 'p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
}, schema=self.pa_schema)
self.assertEqual(actual, expected)
def test_pk_reader_with_filter(self):
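        """Apply an AND of is_in/between/is_not_null predicates to a parquet PK table and
        expect only the rows for user_id 2 and 7."""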
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_pk_filter', schema, False)
table = self.catalog.get_table('default.test_pk_filter')
self._write_test_table(table)
predicate_builder = table.new_read_builder().new_predicate_builder()
p1 = predicate_builder.is_in('dt', ['p1'])
p2 = predicate_builder.between('user_id', 2, 7)
p3 = predicate_builder.is_not_null('behavior')
g1 = predicate_builder.and_predicates([p1, p2, p3])
read_builder = table.new_read_builder().with_filter(g1)
actual = self._read_test_table(read_builder).sort_by('user_id')
expected = pa.concat_tables([
self.expected.slice(1, 1), # 2/b
self.expected.slice(5, 1) # 7/g
])
self.assertEqual(actual, expected)
def test_pk_reader_with_projection(self):
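        """Read a PK table with a column projection and compare against the projected
        expected table."""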
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_pk_projection', schema, False)
table = self.catalog.get_table('default.test_pk_projection')
self._write_test_table(table)
read_builder = table.new_read_builder().with_projection(['dt', 'user_id', 'behavior'])
actual = self._read_test_table(read_builder).sort_by('user_id')
expected = self.expected.select(['dt', 'user_id', 'behavior'])
self.assertEqual(actual, expected)
def test_incremental_timestamp(self):
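        """Read with incremental-between-timestamp ranges: an empty range before the first
        commit, a range covering both snapshots, and a range covering only the second one."""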
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_incremental_parquet', schema, False)
table = self.catalog.get_table('default.test_incremental_parquet')
timestamp = int(time.time() * 1000)
self._write_test_table(table)
snapshot_manager = SnapshotManager(table)
t1 = snapshot_manager.get_snapshot_by_id(1).time_millis
t2 = snapshot_manager.get_snapshot_by_id(2).time_millis
        # Test 1: a range that ends before the first commit yields no rows.
        table = table.copy({CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP.key(): f"{timestamp - 1},{timestamp}"})
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder)
self.assertEqual(len(actual), 0)
        # Test 2: a range from before the first commit up to the second snapshot yields the fully merged table.
        table = table.copy({CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP.key(): f"{timestamp},{t2}"})
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
self.assertEqual(self.expected, actual)
        # Test 3: a range between the two snapshots yields only the rows written by the second commit.
        table = table.copy({CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP.key(): f"{t1},{t2}"})
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
expected = pa.Table.from_pydict({
"user_id": [2, 5, 7, 8],
"item_id": [1002, 1005, 1007, 1008],
"behavior": ["b-new", "e", "g", "h"],
"dt": ["p1", "p2", "p1", "p2"]
}, schema=self.pa_schema)
self.assertEqual(expected, actual)
def test_incremental_read_multi_snapshots(self):
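        """Commit 100 single-row snapshots and read incrementally between snapshot 10 and
        snapshot 20, expecting exactly the rows of snapshots 11 through 20."""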
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_incremental_read_multi_snapshots', schema, False)
table = self.catalog.get_table('default.test_incremental_read_multi_snapshots')
write_builder = table.new_batch_write_builder()
for i in range(1, 101):
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
pa_table = pa.Table.from_pydict({
'user_id': [i],
'item_id': [1000 + i],
'behavior': [f'snap{i}'],
'dt': ['p1' if i % 2 == 1 else 'p2'],
}, schema=self.pa_schema)
table_write.write_arrow(pa_table)
table_commit.commit(table_write.prepare_commit())
table_write.close()
table_commit.close()
snapshot_manager = SnapshotManager(table)
t10 = snapshot_manager.get_snapshot_by_id(10).time_millis
t20 = snapshot_manager.get_snapshot_by_id(20).time_millis
table_inc = table.copy({CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP.key(): f"{t10},{t20}"})
read_builder = table_inc.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
expected = pa.Table.from_pydict({
'user_id': list(range(11, 21)),
'item_id': [1000 + i for i in range(11, 21)],
'behavior': [f'snap{i}' for i in range(11, 21)],
'dt': ['p1' if i % 2 == 1 else 'p2' for i in range(11, 21)],
}, schema=self.pa_schema).sort_by('user_id')
self.assertEqual(expected, actual)
def test_manifest_creation_time_timestamp(self):
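        """Check that manifest entries of the latest snapshot carry a creation_time and that
        its epoch-millis conversion is consistent with the local datetime."""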
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table('default.test_manifest_creation_time', schema, False)
table = self.catalog.get_table('default.test_manifest_creation_time')
self._write_test_table(table)
snapshot_manager = SnapshotManager(table)
latest_snapshot = snapshot_manager.get_latest_snapshot()
read_builder = table.new_read_builder()
table_scan = read_builder.new_scan()
manifest_list_manager = table_scan.starting_scanner.manifest_list_manager
manifest_files = manifest_list_manager.read_all(latest_snapshot)
manifest_file_manager = table_scan.starting_scanner.manifest_file_manager
creation_times_found = []
for manifest_file_meta in manifest_files:
entries = manifest_file_manager.read(manifest_file_meta.file_name, drop_stats=False)
for entry in entries:
if entry.file.creation_time is not None:
                    creation_time = entry.file.creation_time
epoch_millis = entry.file.creation_time_epoch_millis()
self.assertIsNotNone(epoch_millis)
self.assertGreater(epoch_millis, 0)
                    # Derive the expected epoch millis from the local creation datetime via a
                    # mktime/gmtime round trip; 'time' is already imported at module level.
                    local_dt = creation_time.to_local_date_time()
                    local_time_struct = local_dt.timetuple()
                    local_timestamp = time.mktime(local_time_struct)
                    local_time_struct_utc = time.gmtime(local_timestamp)
                    utc_timestamp = time.mktime(local_time_struct_utc)
                    expected_epoch_millis = int(utc_timestamp * 1000)
self.assertEqual(epoch_millis, expected_epoch_millis)
creation_times_found.append(epoch_millis)
self.assertGreater(
len(creation_times_found), 0,
"At least one manifest entry should have creation_time")
def _write_test_table(self, table):
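        """Write the shared test data in two commits: users 1-4 first, then a second commit
        that updates user_id=2 and inserts users 5, 7 and 8."""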
write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
data1 = {
'user_id': [1, 2, 3, 4],
'item_id': [1001, 1002, 1003, 1004],
'behavior': ['a', 'b', 'c', None],
'dt': ['p1', 'p1', 'p2', 'p1'],
}
pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
table_write.write_arrow(pa_table)
table_commit.commit(table_write.prepare_commit())
table_write.close()
table_commit.close()
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
        # Second batch, committed separately: updates user_id=2 and inserts users 5, 7 and 8.
        data2 = {
            'user_id': [5, 2, 7, 8],
            'item_id': [1005, 1002, 1007, 1008],
            'behavior': ['e', 'b-new', 'g', 'h'],
            'dt': ['p2', 'p1', 'p1', 'p2']
        }
        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
table_write.write_arrow(pa_table)
table_commit.commit(table_write.prepare_commit())
table_write.close()
table_commit.close()
def _read_test_table(self, read_builder):
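        """Plan the splits for the given read builder and read them into a pyarrow Table."""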
table_read = read_builder.new_read()
splits = read_builder.new_scan().plan().splits()
return table_read.to_arrow(splits)
def test_concurrent_writes_with_retry(self):
"""Test concurrent writes to verify retry mechanism works correctly for PK tables."""
import threading
# Run the test 3 times to verify stability
iter_num = 3
for test_iteration in range(iter_num):
# Create a unique table for each iteration
table_name = f'default.test_pk_concurrent_writes_{test_iteration}'
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
primary_keys=['user_id', 'dt'],
options={'bucket': '2'})
self.catalog.create_table(table_name, schema, False)
table = self.catalog.get_table(table_name)
write_results = []
write_errors = []
def write_data(thread_id, start_user_id):
"""Write data in a separate thread."""
try:
threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
# Create unique data for this thread
data = {
'user_id': list(range(start_user_id, start_user_id + 5)),
'item_id': [1000 + i for i in range(start_user_id, start_user_id + 5)],
'behavior': [f'thread{thread_id}_{i}' for i in range(5)],
'dt': ['p1' if i % 2 == 0 else 'p2' for i in range(5)],
}
pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
table_write.write_arrow(pa_table)
commit_messages = table_write.prepare_commit()
table_commit.commit(commit_messages)
table_write.close()
table_commit.close()
write_results.append({
'thread_id': thread_id,
'start_user_id': start_user_id,
'success': True
})
except Exception as e:
write_errors.append({
'thread_id': thread_id,
'error': str(e)
})
# Create and start multiple threads
threads = []
num_threads = 10
for i in range(num_threads):
thread = threading.Thread(
target=write_data,
args=(i, i * 10)
)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Verify all writes succeeded (retry mechanism should handle conflicts)
self.assertEqual(num_threads, len(write_results),
f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
f"got {len(write_results)}. Errors: {write_errors}")
self.assertEqual(0, len(write_errors),
f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
read_builder = table.new_read_builder()
actual = self._read_test_table(read_builder).sort_by('user_id')
# Verify data rows (PK table should have unique user_id+dt combinations)
self.assertEqual(num_threads * 5, actual.num_rows,
f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
# Verify user_id
user_ids = actual.column('user_id').to_pylist()
expected_user_ids = []
for i in range(num_threads):
expected_user_ids.extend(range(i * 10, i * 10 + 5))
expected_user_ids.sort()
self.assertEqual(user_ids, expected_user_ids,
f"Iteration {test_iteration}: User IDs mismatch")
# Verify snapshot count (should have num_threads snapshots)
snapshot_manager = SnapshotManager(table)
latest_snapshot = snapshot_manager.get_latest_snapshot()
self.assertIsNotNone(latest_snapshot,
f"Iteration {test_iteration}: Latest snapshot should not be None")
self.assertEqual(latest_snapshot.id, num_threads,
f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
f"got {latest_snapshot.id}")
print(f"✓ PK Table Iteration {test_iteration + 1}/{iter_num} completed successfully")