| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| """Tests for Consumer and ConsumerManager.""" |
| |
| import json |
| import os |
| import shutil |
| import tempfile |
| import unittest |
| from unittest.mock import Mock |
| |
| from pypaimon.consumer.consumer import Consumer |
| from pypaimon.consumer.consumer_manager import ConsumerManager |
| |
| |
class ConsumerTest(unittest.TestCase):
    """Unit tests for the ``Consumer`` value object and its JSON codec."""

    def test_consumer_to_json(self):
        """to_json() must emit the snapshot id under the ``nextSnapshot`` key."""
        payload = json.loads(Consumer(next_snapshot=42).to_json())
        self.assertEqual(payload["nextSnapshot"], 42)

    def test_consumer_from_json(self):
        """from_json() must map ``nextSnapshot`` onto ``next_snapshot``."""
        parsed = Consumer.from_json('{"nextSnapshot": 42}')
        self.assertEqual(parsed.next_snapshot, 42)

    def test_consumer_from_json_ignores_unknown_fields(self):
        """Extra keys in the JSON payload must not break deserialization."""
        raw = '{"nextSnapshot": 42, "unknownField": "value"}'
        parsed = Consumer.from_json(raw)
        self.assertEqual(parsed.next_snapshot, 42)

    def test_consumer_roundtrip(self):
        """Serializing and then deserializing must preserve ``next_snapshot``."""
        before = Consumer(next_snapshot=12345)
        after = Consumer.from_json(before.to_json())
        self.assertEqual(after.next_snapshot, before.next_snapshot)
| |
| |
| class ConsumerManagerTest(unittest.TestCase): |
| """Tests for ConsumerManager.""" |
| |
| def setUp(self): |
| """Create a temporary directory for testing.""" |
| self.tempdir = tempfile.mkdtemp() |
| self.table_path = os.path.join(self.tempdir, "test_table") |
| os.makedirs(self.table_path) |
| |
| # Create mock file_io |
| self.file_io = Mock() |
| self._setup_file_io_mock() |
| |
| def tearDown(self): |
| """Clean up temporary directory.""" |
| shutil.rmtree(self.tempdir, ignore_errors=True) |
| |
| def _setup_file_io_mock(self): |
| """Setup file_io mock to use real filesystem.""" |
| def read_file_utf8(path): |
| with open(path, 'r') as f: |
| return f.read() |
| |
| def overwrite_file_utf8(path, content): |
| os.makedirs(os.path.dirname(path), exist_ok=True) |
| with open(path, 'w') as f: |
| f.write(content) |
| |
| def exists(path): |
| return os.path.exists(path) |
| |
| def delete_quietly(path): |
| if os.path.exists(path): |
| os.remove(path) |
| |
| self.file_io.read_file_utf8 = read_file_utf8 |
| self.file_io.overwrite_file_utf8 = overwrite_file_utf8 |
| self.file_io.exists = exists |
| self.file_io.delete_quietly = delete_quietly |
| |
| def test_consumer_manager_reset_consumer(self): |
| """reset_consumer should write consumer state to file.""" |
| manager = ConsumerManager(self.file_io, self.table_path) |
| consumer = Consumer(next_snapshot=42) |
| |
| manager.reset_consumer("my-consumer", consumer) |
| |
| # Verify file exists |
| consumer_file = os.path.join(self.table_path, "consumer", "consumer-my-consumer") |
| self.assertTrue(os.path.exists(consumer_file)) |
| |
| # Verify content |
| with open(consumer_file, 'r') as f: |
| content = f.read() |
| data = json.loads(content) |
| self.assertEqual(data["nextSnapshot"], 42) |
| |
| def test_consumer_manager_get_consumer(self): |
| """consumer() should read consumer state from file.""" |
| manager = ConsumerManager(self.file_io, self.table_path) |
| |
| # Write consumer file directly |
| consumer_dir = os.path.join(self.table_path, "consumer") |
| os.makedirs(consumer_dir, exist_ok=True) |
| consumer_file = os.path.join(consumer_dir, "consumer-my-consumer") |
| with open(consumer_file, 'w') as f: |
| f.write('{"nextSnapshot": 42}') |
| |
| # Read via manager |
| consumer = manager.consumer("my-consumer") |
| |
| self.assertIsNotNone(consumer) |
| self.assertEqual(consumer.next_snapshot, 42) |
| |
| def test_consumer_manager_get_nonexistent_consumer(self): |
| """consumer() should return None for non-existent consumer.""" |
| manager = ConsumerManager(self.file_io, self.table_path) |
| |
| consumer = manager.consumer("nonexistent") |
| |
| self.assertIsNone(consumer) |
| |
| def test_consumer_manager_delete_consumer(self): |
| """delete_consumer should remove consumer file.""" |
| manager = ConsumerManager(self.file_io, self.table_path) |
| |
| # Create consumer first |
| manager.reset_consumer("my-consumer", Consumer(next_snapshot=42)) |
| consumer_file = os.path.join(self.table_path, "consumer", "consumer-my-consumer") |
| self.assertTrue(os.path.exists(consumer_file)) |
| |
| # Delete |
| manager.delete_consumer("my-consumer") |
| |
| self.assertFalse(os.path.exists(consumer_file)) |
| |
| def test_consumer_manager_update_consumer(self): |
| """reset_consumer should update existing consumer.""" |
| manager = ConsumerManager(self.file_io, self.table_path) |
| |
| # Create initial consumer |
| manager.reset_consumer("my-consumer", Consumer(next_snapshot=42)) |
| |
| # Update |
| manager.reset_consumer("my-consumer", Consumer(next_snapshot=100)) |
| |
| # Verify updated |
| consumer = manager.consumer("my-consumer") |
| self.assertEqual(consumer.next_snapshot, 100) |
| |
| def test_consumer_path(self): |
| """Consumer files should be in {table_path}/consumer/consumer-{id}.""" |
| manager = ConsumerManager(self.file_io, self.table_path) |
| |
| path = manager._consumer_path("test-id") |
| |
| expected = f"{self.table_path}/consumer/consumer-test-id" |
| self.assertEqual(path, expected) |
| |
| def test_validate_rejects_empty(self): |
| manager = ConsumerManager(self.file_io, self.table_path) |
| with self.assertRaises(ValueError): |
| manager._consumer_path("") |
| |
| def test_validate_rejects_path_separators(self): |
| manager = ConsumerManager(self.file_io, self.table_path) |
| for bad_id in ("foo/bar", "foo\\bar"): |
| with self.assertRaises(ValueError, msg=bad_id): |
| manager._consumer_path(bad_id) |
| |
| def test_validate_rejects_relative_components(self): |
| manager = ConsumerManager(self.file_io, self.table_path) |
| for bad_id in (".", ".."): |
| with self.assertRaises(ValueError, msg=bad_id): |
| manager._consumer_path(bad_id) |
| |
| |
if __name__ == "__main__":
    # Allow the module to be run directly as a script.
    unittest.main()