| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| |
| """Tests for StreamReadBuilder.""" |
| |
| from unittest.mock import MagicMock |
| |
| import pytest |
| |
| from pypaimon.read.stream_read_builder import StreamReadBuilder |
| from pypaimon.read.streaming_table_scan import AsyncStreamingTableScan |
| |
| |
class MockEntry:
    """Minimal stand-in for a manifest entry used by the filtering tests.

    Only the ``bucket`` attribute is provided; it is the single attribute the
    tests in this module set and inspect on an entry.
    """

    def __init__(self, bucket: int) -> None:
        # Keep the bucket id exactly as given so assertions can compare it.
        self.bucket = bucket
| |
| |
@pytest.fixture
def mock_table():
    """Provide a ``MagicMock`` table for the builder unit tests.

    The mock has an empty ``fields`` list and an
    ``options.row_tracking_enabled()`` call that returns ``False``.
    """
    fake_table = MagicMock()
    # The two settings are independent of each other; order does not matter.
    fake_table.options.row_tracking_enabled.return_value = False
    fake_table.fields = []
    return fake_table
| |
| |
@pytest.fixture
def builder(mock_table):
    """Provide a ``StreamReadBuilder`` built on the ``mock_table`` fixture."""
    stream_read_builder = StreamReadBuilder(mock_table)
    return stream_read_builder
| |
| |
@pytest.fixture
def mock_scan_table():
    """Provide a ``MagicMock`` table for the AsyncStreamingTableScan tests.

    The mock exposes a mocked ``file_io``, a fixed ``table_path``, an empty
    ``fields`` list, and stubbed ``options.changelog_producer()`` /
    ``options.row_tracking_enabled()`` calls.
    """
    scan_table = MagicMock()

    # Stub the option accessors that are called as methods.
    options = scan_table.options
    options.changelog_producer.return_value = MagicMock()
    options.row_tracking_enabled.return_value = False

    # Plain attributes read directly off the table.
    scan_table.configure_mock(
        file_io=MagicMock(),
        table_path="/tmp/test",
        fields=[],
    )
    return scan_table
| |
| |
class TestStreamReadBuilderValidation:
    """Unit tests for the fluent configuration methods of StreamReadBuilder."""

    def test_with_bucket_filter_valid(self, builder):
        """with_bucket_filter() stores the given callable and returns the builder."""

        # A named function instead of a lambda bound to a name (PEP 8, E731).
        def filter_fn(b):
            return b % 2 == 0

        result = builder.with_bucket_filter(filter_fn)
        assert result is builder
        assert builder._bucket_filter is filter_fn

    @pytest.mark.parametrize("bucket_ids,expected_true,expected_false", [
        ([0, 2, 4], [0, 2, 4], [1, 3, 5]),
        ([], [], [0, 1, 2]),
        ([5], [5], [0, 1, 4, 6]),
    ])
    def test_with_buckets(self, builder, bucket_ids, expected_true, expected_false):
        """with_buckets() installs a filter accepting exactly the listed bucket ids."""
        builder.with_buckets(bucket_ids)
        for b in expected_true:
            assert builder._bucket_filter(b), f"Bucket {b} should be included"
        for b in expected_false:
            assert not builder._bucket_filter(b), f"Bucket {b} should be excluded"

    def test_with_consumer_id(self, builder):
        """with_consumer_id() stores the consumer id and returns the builder."""
        result = builder.with_consumer_id("my-consumer")
        assert result is builder
        assert builder._consumer_id == "my-consumer"

    def test_method_chaining(self, builder):
        """Chained configuration calls each return the builder and keep their values."""

        def even_buckets(b):
            return b % 2 == 0

        result = (builder
                  .with_poll_interval_ms(500)
                  .with_bucket_filter(even_buckets)
                  .with_include_row_kind(True))
        assert result is builder
        assert builder._poll_interval_ms == 500
        # Check identity rather than "is not None", so the test fails if the
        # chain stores a different callable than the one passed in. This is the
        # same check test_with_bucket_filter_valid makes.
        assert builder._bucket_filter is even_buckets
        assert builder._include_row_kind is True
| |
| |
class TestAsyncStreamingTableScanFiltering:
    """Tests for AsyncStreamingTableScan._filter_entries_for_shard()."""

    def test_filter_with_bucket_filter(self, mock_scan_table):
        """A custom bucket filter keeps only the entries whose bucket it accepts."""

        def keep_even(bucket_id):
            return bucket_id % 2 == 0

        streaming_scan = AsyncStreamingTableScan(
            table=mock_scan_table,
            bucket_filter=keep_even,
        )
        all_entries = [MockEntry(bucket_id) for bucket_id in range(8)]

        kept_entries = streaming_scan._filter_entries_for_shard(all_entries)

        kept_buckets = [entry.bucket for entry in kept_entries]
        assert kept_buckets == [0, 2, 4, 6]

    def test_filter_no_filter_returns_all(self, mock_scan_table):
        """Without a bucket filter every entry is returned, in order."""
        streaming_scan = AsyncStreamingTableScan(table=mock_scan_table)
        all_entries = [MockEntry(bucket_id) for bucket_id in range(8)]

        kept_entries = streaming_scan._filter_entries_for_shard(all_entries)

        kept_buckets = [entry.bucket for entry in kept_entries]
        expected_buckets = list(range(8))
        assert kept_buckets == expected_buckets
| |
| |
class TestConsumerIdPassthrough:
    """Tests that the builder's consumer id reaches the scan it creates."""

    def test_new_streaming_scan_passes_consumer_id(self, mock_scan_table):
        """A consumer id set on the builder is visible on the new scan."""
        read_builder = StreamReadBuilder(mock_scan_table)
        read_builder.with_consumer_id("test-consumer")

        streaming_scan = read_builder.new_streaming_scan()

        # The scan should carry the id and have a consumer manager set.
        assert streaming_scan._consumer_id == "test-consumer"
        assert streaming_scan._consumer_manager is not None

    def test_new_streaming_scan_no_consumer_by_default(self, mock_scan_table):
        """Without with_consumer_id(), the scan has no consumer id or manager."""
        read_builder = StreamReadBuilder(mock_scan_table)

        streaming_scan = read_builder.new_streaming_scan()

        assert streaming_scan._consumer_id is None
        assert streaming_scan._consumer_manager is None
| |
| |
if __name__ == '__main__':
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Propagate it so that running this file as a script reports test failures
    # through a non-zero process status.
    raise SystemExit(pytest.main([__file__, '-v']))