| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """tests for Quantum Data Loader.""" |
| |
| from __future__ import annotations |
| |
| from typing import TYPE_CHECKING, cast |
| |
| import pytest |
| |
| if TYPE_CHECKING: |
| from qumat_qdp.loader import QuantumDataLoader as QuantumDataLoaderType |
| |
| try: |
| from qumat_qdp.loader import QuantumDataLoader |
| except ImportError: |
| QuantumDataLoader: type[QuantumDataLoaderType] | None = None |
| |
| |
def _loader_available():
    """Report whether the guarded import of QuantumDataLoader succeeded.

    Evaluated at collection time by the ``skipif`` markers in this module.
    """
    import_failed = QuantumDataLoader is None
    return not import_failed
| |
| |
def _require_loader_cls() -> type[QuantumDataLoaderType]:
    """Return the QuantumDataLoader class, skipping the running test if it is missing.

    The ``cast`` only narrows the optional module-level name for type checkers;
    it performs no runtime conversion.
    """
    resolved = QuantumDataLoader
    if resolved is None:
        # pytest.skip raises, so the cast below is never reached with None.
        pytest.skip("QuantumDataLoader not available")
    return cast("type[QuantumDataLoaderType]", resolved)
| |
| |
@pytest.fixture
def loader_cls() -> type[QuantumDataLoaderType]:
    """Provide the QuantumDataLoader class to a test, skipping it when unavailable.

    Going through the helper gives type checkers a non-optional class type.
    """
    available_cls = _require_loader_cls()
    return available_cls
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_mutual_exclusion_both_sources_raises(
    loader_cls: type[QuantumDataLoaderType],
):
    """Choosing a synthetic source and then a file source fails with ValueError on iteration.

    The error text must contain the conflict phrase and name both builder methods.
    """
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    conflicted = builder.source_synthetic().source_file("/tmp/any.parquet")

    # The conflict is only expected once the loader is iterated.
    with pytest.raises(ValueError) as raised:
        list(conflicted)

    text = str(raised.value)
    for fragment in (
        "Cannot set both synthetic and file sources",
        "source_synthetic",
        "source_file",
    ):
        assert fragment in text
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_mutual_exclusion_exact_message(loader_cls: type[QuantumDataLoaderType]):
    """Choosing the file source first and the synthetic source second still fails.

    Only the "Cannot set both synthetic and file sources" phrase of the
    ValueError message is asserted here.
    """
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    conflicted = builder.source_file("/tmp/x.npy").source_synthetic()

    with pytest.raises(ValueError) as raised:
        list(conflicted)

    expected_phrase = "Cannot set both synthetic and file sources"
    assert expected_phrase in str(raised.value)
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_source_file_empty_path_raises(loader_cls: type[QuantumDataLoaderType]):
    """An empty string passed to source_file() is rejected with ValueError.

    The message is expected to mention "path" (case-insensitive).
    """
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)

    with pytest.raises(ValueError) as raised:
        builder.source_file("")

    lowered = str(raised.value).lower()
    assert "path" in lowered
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_synthetic_loader_batch_count(loader_cls: type[QuantumDataLoaderType]):
    """Iterating a synthetic loader produces exactly the configured number of batches.

    The test is skipped when iteration reports that the platform or backend is
    unavailable; any other RuntimeError propagates.
    """
    expected_batches = 5
    per_batch = 4
    builder = loader_cls(device_id=0).qubits(4)
    loader = builder.batches(expected_batches, size=per_batch).source_synthetic()

    try:
        produced = list(loader)
    except RuntimeError as err:
        reason = str(err)
        unavailable_markers = ("only available on Linux", "not available")
        if any(marker in reason for marker in unavailable_markers):
            pytest.skip("CUDA/Linux required for loader iteration")
        raise

    assert len(produced) == expected_batches
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_file_loader_unsupported_extension_raises(
    loader_cls: type[QuantumDataLoaderType],
):
    """Iterating a loader whose file has an unsupported extension raises RuntimeError.

    Outcomes:
    - RuntimeError mentioning "not available": skipped (file loader not built).
    - Any other RuntimeError: its message must mention the unsupported/supported
      extension.
    - ValueError: skipped, since validation may have happened before the backend.
    - No exception at all: the test fails.
    """
    loader = (
        loader_cls(device_id=0)
        .qubits(4)
        .batches(10, size=4)
        .source_file("/tmp/data.unsupported")
    )
    try:
        list(loader)
    except RuntimeError as e:
        msg = str(e).lower()
        if "not available" in msg:
            # pytest.skip raises, so nothing after it runs in this branch.
            pytest.skip(
                "create_file_loader not available (e.g. extension built without loader)"
            )
        assert any(
            keyword in msg for keyword in ("unsupported", "extension", "supported")
        )
    except ValueError:
        pytest.skip("Loader may validate path before Rust")
    else:
        # Reaching here means iteration succeeded, which is the failure case.
        pytest.fail("Expected RuntimeError for unsupported file extension")
| |
| |
| # --- Streaming (source_file(..., streaming=True)) tests --- |
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_streaming_requires_parquet(loader_cls: type[QuantumDataLoaderType]):
    """Requesting streaming for a non-.parquet path is rejected with ValueError.

    The message is expected to mention "parquet" or "streaming" (case-insensitive).
    """
    npy_path = "/tmp/data.npy"

    # The whole builder chain stays inside the context so that a ValueError from
    # any step is captured, exactly as before.
    with pytest.raises(ValueError) as raised:
        loader_cls(device_id=0).qubits(4).batches(10, size=4).source_file(
            npy_path, streaming=True
        )

    lowered = str(raised.value).lower()
    assert any(word in lowered for word in ("parquet", "streaming"))
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_streaming_parquet_extension_ok(loader_cls: type[QuantumDataLoaderType]):
    """A .parquet path with streaming=True is accepted by the builder without raising."""
    parquet_path = "/tmp/data.parquet"
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    loader = builder.source_file(parquet_path, streaming=True)

    # The loader is deliberately not iterated: that could raise RuntimeError
    # (no CUDA) or fail on the missing file. Only builder state is checked.
    assert loader._streaming_requested is True
    assert loader._file_path == parquet_path
| |
| |
| # --- NullHandling builder tests --- |
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_fill_zero(loader_cls: type[QuantumDataLoaderType]):
    """The 'fill_zero' policy passed to null_handling() is stored on the loader."""
    policy = "fill_zero"
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    configured = builder.null_handling(policy)
    assert configured._null_handling == policy
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_reject(loader_cls: type[QuantumDataLoaderType]):
    """The 'reject' policy passed to null_handling() is stored on the loader."""
    policy = "reject"
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    configured = builder.null_handling(policy)
    assert configured._null_handling == policy
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_invalid_raises(loader_cls: type[QuantumDataLoaderType]):
    """An unrecognised null-handling policy is rejected with ValueError.

    The message is expected to name at least one accepted policy.
    """
    with pytest.raises(ValueError) as raised:
        loader_cls(device_id=0).null_handling("invalid_policy")

    text = str(raised.value)
    assert any(policy in text for policy in ("fill_zero", "reject"))
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_default_is_none(loader_cls: type[QuantumDataLoaderType]):
    """A freshly constructed loader has ``_null_handling`` set to None.

    Per the original note, the Rust side is then expected to use FillZero;
    that fallback is not exercised by this test.
    """
    fresh_loader = loader_cls(device_id=0)
    assert fresh_loader._null_handling is None
| |
| |
| # --- Remote URL (source_file) builder tests --- |
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.parametrize(
    ("path", "streaming"),
    [
        pytest.param("s3://my-bucket/data.parquet", False, id="parquet-no-stream"),
        pytest.param("s3://bucket/path/to/data.parquet", True, id="parquet-stream"),
        pytest.param("s3://bucket/data.npy", False, id="npy-no-stream"),
        pytest.param(
            "gs://my-bucket/data.parquet", False, id="gcs-parquet-no-stream"
        ),
        pytest.param(
            "gs://bucket/path/to/data.parquet", True, id="gcs-parquet-stream"
        ),
        pytest.param("gs://bucket/data.npy", False, id="gcs-npy-no-stream"),
    ],
)
def test_source_file_remote_url_accepted(path, streaming):
    """The builder accepts well-formed s3:// and gs:// URLs and records the request.

    Only builder state is inspected; the loader is never iterated.
    """
    builder = _require_loader_cls()(device_id=0).qubits(4).batches(10, size=4)
    configured = builder.source_file(path, streaming=streaming)

    assert configured._file_path == path
    assert configured._file_requested is True
    assert configured._streaming_requested is streaming
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.parametrize(
    "path",
    [
        pytest.param("s3://bucket/data.npy", id="s3-npy"),
        pytest.param("gs://bucket/data.npy", id="gcs-npy"),
    ],
)
def test_source_file_remote_streaming_non_parquet_raises(path):
    """Streaming from a remote URL that is not .parquet is rejected with ValueError.

    The message is expected to mention "parquet" or "streaming" (case-insensitive).
    """
    available_cls = _require_loader_cls()

    with pytest.raises(ValueError) as raised:
        available_cls(device_id=0).qubits(4).batches(10, size=4).source_file(
            path, streaming=True
        )

    lowered = str(raised.value).lower()
    assert any(word in lowered for word in ("parquet", "streaming"))
| |
| |
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.parametrize(
    "path",
    [
        pytest.param("s3://bucket/data.parquet?versionId=abc", id="s3-query"),
        pytest.param("s3://bucket/data.parquet#v1", id="s3-fragment"),
        pytest.param("gs://bucket/data.parquet?generation=123", id="gcs-query"),
        pytest.param("gs://bucket/data.parquet#v2", id="gcs-fragment"),
    ],
)
def test_source_file_remote_query_fragment_raises(path):
    """Remote URLs carrying a query string or fragment are rejected with ValueError.

    The message is expected to mention "query", "fragment", or the
    "scheme://bucket/key" form (case-insensitive).
    """
    available_cls = _require_loader_cls()

    with pytest.raises(ValueError) as raised:
        available_cls(device_id=0).qubits(4).batches(10, size=4).source_file(path)

    lowered = str(raised.value).lower()
    accepted_hints = ("query", "fragment", "scheme://bucket/key")
    assert any(hint in lowered for hint in accepted_hints)