blob: a9419825a7f511feed15fc266238ee44447ee343 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import shutil
import tempfile
import unittest
import pyarrow as pa
from pypaimon import CatalogFactory, Schema
from pypaimon.branch.branch_manager import BranchManager, DEFAULT_MAIN_BRANCH
from pypaimon.branch.filesystem_branch_manager import FileSystemBranchManager
from pypaimon.common.identifier import Identifier
class BranchManagerTest(unittest.TestCase):
    """Unit tests for the static helpers on ``BranchManager``.

    None of these tests touch the filesystem: they only check pure name
    normalisation, path construction and argument validation.
    """

    def test_is_main_branch(self):
        """Only ``main`` is reported as the main branch."""
        self.assertTrue(BranchManager.is_main_branch("main"))
        for other_branch in ("feature", "develop"):
            self.assertFalse(BranchManager.is_main_branch(other_branch))

    def test_normalize_branch(self):
        """None/blank names map to the default branch; others are trimmed."""
        expectations = (
            (None, DEFAULT_MAIN_BRANCH),
            ("", DEFAULT_MAIN_BRANCH),
            (" ", DEFAULT_MAIN_BRANCH),
            ("main", "main"),
            ("feature", "feature"),
            (" feature ", "feature"),
        )
        for raw_name, normalized in expectations:
            self.assertEqual(BranchManager.normalize_branch(raw_name), normalized)

    def test_branch_path(self):
        """Main resolves to the table root; other branches live under branch/."""
        root = "/path/to/table"
        self.assertEqual(BranchManager.branch_path(root, "main"), root)
        expectations = (
            ("feature", "/path/to/table/branch/branch-feature"),
            ("develop", "/path/to/table/branch/branch-develop"),
        )
        for branch_name, expected_path in expectations:
            self.assertEqual(
                BranchManager.branch_path(root, branch_name),
                expected_path
            )

    def test_validate_branch_main_branch(self):
        """The default branch name cannot be used as a new branch."""
        with self.assertRaisesRegex(ValueError, "default branch"):
            BranchManager.validate_branch("main")

    def test_validate_branch_blank(self):
        """Empty and whitespace-only names are rejected as blank."""
        for blank_name in ("", " "):
            with self.assertRaisesRegex(ValueError, "blank"):
                BranchManager.validate_branch(blank_name)

    def test_validate_branch_numeric(self):
        """Purely numeric names are rejected."""
        with self.assertRaisesRegex(ValueError, "pure numeric"):
            BranchManager.validate_branch("123")

    def test_validate_branch_valid(self):
        """Ordinary names pass validation without raising."""
        for good_name in ("feature", "develop", "feature-branch-123"):
            BranchManager.validate_branch(good_name)

    def test_fast_forward_validate_to_main(self):
        """Fast-forwarding onto ``main`` itself is refused."""
        with self.assertRaisesRegex(ValueError, "do not use in fast-forward"):
            BranchManager.fast_forward_validate("main", "feature")

    def test_fast_forward_validate_blank(self):
        """A blank target branch is refused."""
        with self.assertRaisesRegex(ValueError, "blank"):
            BranchManager.fast_forward_validate("", "feature")

    def test_fast_forward_validate_same_branch(self):
        """Fast-forwarding a branch onto itself is refused."""
        with self.assertRaisesRegex(ValueError, "from the current branch"):
            BranchManager.fast_forward_validate("feature", "feature")
class FileSystemBranchManagerTest(unittest.TestCase):
    """Test FileSystemBranchManager path helpers and branch discovery."""

    def setUp(self):
        """Create a scratch directory and a LocalFileIO for the table path."""
        from pypaimon.filesystem.local_file_io import LocalFileIO
        self.temp_dir = tempfile.mkdtemp()
        self.table_path = os.path.join(self.temp_dir, "test_table")
        self.file_io = LocalFileIO(self.table_path)

    def tearDown(self):
        """Remove the scratch directory (same cleanup as the other suites)."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _new_manager(self):
        """Build a FileSystemBranchManager wired only to file_io/table_path.

        The snapshot, tag and schema managers are passed as None; the tests
        in this class only call path helpers, ``branches`` and
        ``branch_exists``.
        """
        return FileSystemBranchManager(
            file_io=self.file_io,
            table_path=self.table_path,
            snapshot_manager=None,
            tag_manager=None,
            schema_manager=None
        )

    def test_branch_directory(self):
        """Branches are stored under ``{table_path}/branch``."""
        manager = self._new_manager()
        expected_dir = f"{self.table_path}/branch"
        self.assertEqual(manager._branch_directory(), expected_dir)

    def test_branch_path(self):
        """Main maps to the table root; other branches to branch/branch-<name>."""
        manager = self._new_manager()
        self.assertEqual(manager.branch_path("main"), self.table_path)
        self.assertEqual(
            manager.branch_path("feature"),
            f"{self.table_path}/branch/branch-feature"
        )

    def test_branches_empty(self):
        """Test branches returns empty list when no branches exist."""
        manager = self._new_manager()
        self.assertEqual(manager.branches(), [])

    def test_branch_exists(self):
        """A branch exists iff its branch-<name> directory is present."""
        manager = self._new_manager()
        # Create the on-disk directory for "feature" only.
        branch_dir = os.path.join(self.table_path, "branch", "branch-feature")
        os.makedirs(branch_dir, exist_ok=True)
        self.assertTrue(manager.branch_exists("feature"))
        self.assertFalse(manager.branch_exists("develop"))
class SnapshotManagerBranchAwarenessTest(unittest.TestCase):
    """Pin down SnapshotManager's branch-aware path computation.

    Mirrors Java ``SnapshotManager.copyWithBranch`` (utils/SnapshotManager.java):
    a non-main branch must produce paths under
    ``{table_path}/branch/branch-{name}/snapshot/...``.
    """

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp(prefix="unittest_snapshot_branch_")
        warehouse = os.path.join(self.temp_dir, "warehouse")
        os.makedirs(warehouse, exist_ok=True)
        self.catalog = CatalogFactory.create({"warehouse": warehouse})
        self.catalog.create_database("default", True)
        self.pa_schema = pa.schema([("id", pa.int64()), ("value", pa.string())])
        self.identifier = Identifier.from_string("default.snapshot_branch_table")
        self.catalog.create_table(
            self.identifier, Schema.from_pyarrow_schema(self.pa_schema), False)
        self.table = self.catalog.get_table(self.identifier)
        # Table root without a trailing slash so expected paths can be
        # built with plain "/" joins.
        self.table_root = self.table.table_path.rstrip('/')

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_default_constructor_targets_main_branch(self):
        sm = self.table.snapshot_manager()
        self.assertEqual(sm.branch, "main")
        self.assertEqual(sm.snapshot_dir, f"{self.table_root}/snapshot")

    def test_copy_with_branch_returns_branch_aware_paths(self):
        sm = self.table.snapshot_manager()
        branch_sm = sm.copy_with_branch("b1")
        self.assertIsNot(branch_sm, sm)
        self.assertEqual(branch_sm.branch, "b1")
        expected_dir = f"{self.table_root}/branch/branch-b1/snapshot"
        self.assertEqual(branch_sm.snapshot_dir, expected_dir)
        self.assertEqual(branch_sm.get_snapshot_path(7), f"{expected_dir}/snapshot-7")
        self.assertNotEqual(sm.get_snapshot_path(7), branch_sm.get_snapshot_path(7))

    def test_copy_with_branch_rebranches_snapshot_loader(self):
        # Identifier comes from the module-level import; only SnapshotLoader
        # needs a local import here.
        from pypaimon.snapshot.snapshot_loader import SnapshotLoader
        sm = self.table.snapshot_manager()
        # FileSystem path has no loader; inject one to exercise the
        # rebranch code path. Mirrors Java SnapshotLoaderImpl.copyWithBranch.
        # The catalog_loader is an opaque sentinel: it is never invoked here.
        sm.snapshot_loader = SnapshotLoader(
            catalog_loader=object(),
            identifier=Identifier(database="default",
                                  object="snapshot_branch_table"),
        )
        branch_sm = sm.copy_with_branch("b1")
        self.assertIsNotNone(branch_sm.snapshot_loader)
        self.assertIsNot(branch_sm.snapshot_loader, sm.snapshot_loader)
        self.assertEqual(branch_sm.snapshot_loader.identifier.branch, "b1")
        self.assertEqual(
            branch_sm.snapshot_loader.identifier.database,
            sm.snapshot_loader.identifier.database)
        self.assertEqual(
            branch_sm.snapshot_loader.identifier.get_table_name(),
            sm.snapshot_loader.identifier.get_table_name())
        # Original loader's identifier untouched.
        self.assertIsNone(sm.snapshot_loader.identifier.branch)
class FileSystemBranchManagerEndToEndTest(unittest.TestCase):
    """Catalog-driven end-to-end tests that exercise ``_copy_with_branch``.

    These regress the from-tag and fast-forward paths that previously
    raised ``SameFileError``: the SnapshotManager produced by
    ``_copy_with_branch`` still pointed at the main-branch snapshot dir,
    so ``copy_file(src, dst)`` collapsed to ``src == dst``.
    """

    def setUp(self):
        """Create a table in a fresh warehouse and commit one batch to it."""
        self.temp_dir = tempfile.mkdtemp(prefix="unittest_fs_branch_e2e_")
        warehouse = os.path.join(self.temp_dir, "warehouse")
        os.makedirs(warehouse, exist_ok=True)
        self.catalog = CatalogFactory.create({"warehouse": warehouse})
        self.catalog.create_database("default", True)
        self.pa_schema = pa.schema([("id", pa.int64()), ("value", pa.string())])
        self.identifier = Identifier.from_string("default.fs_branch_e2e_table")
        self.catalog.create_table(
            self.identifier, Schema.from_pyarrow_schema(self.pa_schema), False)
        table = self.catalog.get_table(self.identifier)
        write_builder = table.new_batch_write_builder()
        table_write = write_builder.new_write()
        table_commit = write_builder.new_commit()
        # Release the writer and committer even if writing or committing
        # fails; previously the writer leaked on error and the committer
        # was never closed.
        try:
            table_write.write_arrow(pa.Table.from_pydict(
                {"id": [1, 2, 3], "value": ["a", "b", "c"]},
                schema=self.pa_schema))
            table_commit.commit(table_write.prepare_commit())
        finally:
            table_write.close()
            table_commit.close()
        # Reload so the table sees the snapshot produced by the commit above.
        self.table = self.catalog.get_table(self.identifier)
        self.table_root = self.table.table_path.rstrip('/')

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _latest_snapshot_id(self) -> int:
        """Return the id of the table's latest snapshot (asserting it exists)."""
        latest = self.table.snapshot_manager().get_latest_snapshot()
        self.assertIsNotNone(latest)
        return latest.id

    def test_create_branch_from_tag_lands_files_under_branch_dir(self):
        snapshot_id = self._latest_snapshot_id()
        self.table.create_tag("t1")
        bm = self.table.branch_manager()
        bm.create_branch("b1", tag_name="t1")
        branch_root = f"{self.table_root}/branch/branch-b1"
        self.assertTrue(os.path.isdir(branch_root))
        self.assertTrue(
            os.path.isfile(f"{branch_root}/snapshot/snapshot-{snapshot_id}"))
        self.assertTrue(os.path.isfile(f"{branch_root}/tag/tag-t1"))
        # At least one schema file (schema-0) must have been copied.
        schema_dir = f"{branch_root}/schema"
        self.assertTrue(os.path.isdir(schema_dir))
        self.assertTrue(
            any(name.startswith("schema-") for name in os.listdir(schema_dir)))

    def test_fast_forward_after_create_branch_from_tag(self):
        self.table.create_tag("t1")
        bm = self.table.branch_manager()
        bm.create_branch("b1", tag_name="t1")
        # Must not raise: previously fast_forward's copy_files(src=branch_dir,
        # dst=main_dir) collapsed to src == dst because the branch-side
        # SnapshotManager still pointed at the main snapshot dir.
        bm.fast_forward("b1")
# Allow running this module directly (``python <file>``) in addition to
# discovery by a test runner.
if __name__ == '__main__':
    unittest.main()