blob: 69c1bdfc3f00adb61782b6874a2c2477ec29d3cd [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the manager-level helpers that back system tables.
The snapshots / schemas system tables need bulk-listing helpers that
the corresponding managers had not surfaced. This file pins down
their contracts:
* :meth:`SnapshotManager.list_snapshots` — enumerate persisted
snapshots in ID order, skipping IDs whose files have been expired.
* :meth:`SchemaManager.list_all` — return every committed table
schema in ID order.
"""
import os
import shutil
import tempfile
import time
import unittest
from pypaimon import CatalogFactory, Schema
from pypaimon.common.file_io import FileIO
from pypaimon.common.json_util import JSON
from pypaimon.schema.data_types import DataField
from pypaimon.schema.schema_change import SchemaChange
from pypaimon.snapshot.snapshot import Snapshot
from pypaimon.snapshot.snapshot_manager import SnapshotManager
def _write_snapshot(file_io: FileIO, snapshot_dir: str, snapshot_id: int,
                    schema_id: int = 0):
    """Persist a hand-built snapshot file named ``snapshot-<id>``.

    The snapshot is deliberately minimal: format version 3, manifest-list
    names derived from ``snapshot_id``, zero record counts, an ``APPEND``
    commit by ``test-user`` whose commit identifier equals ``snapshot_id``,
    and the current wall-clock time in milliseconds.

    Args:
        file_io: File IO used to write the JSON payload.
        snapshot_dir: Directory the snapshot file is written into.
        snapshot_id: ID of the snapshot; also used in the file name.
        schema_id: Schema ID recorded in the snapshot (default ``0``).
    """
    # NOTE(review): whatever ``try_to_write_atomic`` returns is discarded,
    # so a failed write is not reported here.
    created_at_ms = int(time.time() * 1000)
    target_path = "{}/snapshot-{}".format(snapshot_dir, snapshot_id)
    minimal = Snapshot(
        version=3,
        id=snapshot_id,
        schema_id=schema_id,
        base_manifest_list="base-{}.avro".format(snapshot_id),
        delta_manifest_list="delta-{}.avro".format(snapshot_id),
        total_record_count=0,
        delta_record_count=0,
        commit_user="test-user",
        commit_identifier=snapshot_id,
        commit_kind="APPEND",
        time_millis=created_at_ms,
    )
    payload = JSON.to_json(minimal)
    file_io.try_to_write_atomic(target_path, payload)
def _new_warehouse():
tmp = tempfile.mkdtemp(prefix="mgrext_")
return tmp, os.path.join(tmp, "warehouse")
class SnapshotManagerListSnapshotsTest(unittest.TestCase):
    """Contract tests for :meth:`SnapshotManager.list_snapshots`.

    Each test gets a fresh filesystem catalog with a single-column table
    ``db.t``; snapshot files are written by hand via ``_write_snapshot``
    rather than through real commits.
    """

    def setUp(self):
        self.tmp, self.warehouse = _new_warehouse()
        # Register removal of the scratch dir before anything else can fail:
        # unittest does not run tearDown() when setUp() raises, but it always
        # runs registered cleanups, so this avoids leaking temp directories.
        self.addCleanup(shutil.rmtree, self.tmp, ignore_errors=True)
        self.catalog = CatalogFactory.create({"warehouse": self.warehouse})
        self.catalog.create_database("db", False)
        fields = [DataField.from_dict({"id": 0, "name": "v", "type": "INT"})]
        self.catalog.create_table("db.t", Schema(fields=fields), False)
        self.table = self.catalog.get_table("db.t")
        self.file_io = self.table.file_io
        # Snapshots are written by hand below; no commit has happened yet, so
        # ensure the snapshot directory exists before writing into it.
        self.snapshot_dir = "{}/snapshot".format(self.table.table_path)
        self.file_io.mkdirs(self.snapshot_dir)
        self.manager = SnapshotManager(
            self.file_io, self.table.table_path, branch="main"
        )

    def test_empty_table_returns_empty_list(self):
        self.assertEqual([], self.manager.list_snapshots())

    def test_lists_persisted_snapshots_in_id_order(self):
        for sid in (3, 1, 2):  # write out of order to prove sort
            _write_snapshot(self.file_io, self.snapshot_dir, sid)
        result = self.manager.list_snapshots()
        self.assertEqual([1, 2, 3], [s.id for s in result])

    def test_orders_multi_digit_ids_numerically(self):
        # With only single-digit IDs, a lexicographic sort of file names
        # ("snapshot-10" < "snapshot-2") would go unnoticed; cross the
        # 9 -> 10 boundary with a contiguous, gap-free range.
        expected = list(range(1, 11))
        for sid in reversed(expected):
            _write_snapshot(self.file_io, self.snapshot_dir, sid)
        result = self.manager.list_snapshots()
        self.assertEqual(expected, [s.id for s in result])

    def test_skips_gaps_from_expired_snapshots(self):
        # snapshot-2 missing — simulates an expired ID still bracketed by
        # earlier and later live snapshots.
        for sid in (1, 3):
            _write_snapshot(self.file_io, self.snapshot_dir, sid)
        result = self.manager.list_snapshots()
        self.assertEqual([1, 3], [s.id for s in result])
class SchemaManagerListAllTest(unittest.TestCase):
    """Contract tests for :meth:`SchemaManager.list_all`.

    Each test gets a fresh filesystem catalog with a single-column table
    ``db.t``; additional schema versions are produced by committing
    ``set_option`` changes through the table's schema manager.
    """

    def setUp(self):
        self.tmp, self.warehouse = _new_warehouse()
        # Register removal of the scratch dir before anything else can fail:
        # unittest does not run tearDown() when setUp() raises, but it always
        # runs registered cleanups, so this avoids leaking temp directories.
        self.addCleanup(shutil.rmtree, self.tmp, ignore_errors=True)
        self.catalog = CatalogFactory.create({"warehouse": self.warehouse})
        self.catalog.create_database("db", False)
        fields = [
            DataField.from_dict({"id": 0, "name": "v", "type": "INT"}),
        ]
        self.catalog.create_table("db.t", Schema(fields=fields), False)
        self.manager = self.catalog.get_table("db.t").schema_manager

    def test_returns_single_schema_initially(self):
        result = self.manager.list_all()
        self.assertEqual([0], [s.id for s in result])

    def test_lists_every_committed_schema_in_id_order(self):
        for new_value in ("a", "b", "c"):
            self.manager.commit_changes([SchemaChange.set_option(
                "user-tag", new_value)])
        result = self.manager.list_all()
        self.assertEqual([0, 1, 2, 3], [s.id for s in result])
        # The listed schemas must be the committed ones, not just their IDs.
        self.assertEqual("c", result[-1].options.get("user-tag"))
        # Repeated calls with no intervening commit must agree.
        again = self.manager.list_all()
        self.assertEqual([0, 1, 2, 3], [s.id for s in again])

    def test_reflects_commits_made_after_a_previous_call(self):
        # Guards against a cache returning stale results: list once, commit,
        # then list again and expect the new schema to appear.
        self.assertEqual([0], [s.id for s in self.manager.list_all()])
        self.manager.commit_changes([SchemaChange.set_option(
            "user-tag", "d")])
        self.assertEqual([0, 1], [s.id for s in self.manager.list_all()])
if __name__ == "__main__":
    # Allow running this module directly (``python <this file>``) in addition
    # to discovery by a test runner.
    unittest.main()