blob: 283d67322e266b8e250a653e88e101ea004c404f [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tests for Quantum Data Loader."""
from __future__ import annotations
from typing import TYPE_CHECKING, cast
import pytest
if TYPE_CHECKING:
from qumat_qdp.loader import QuantumDataLoader as QuantumDataLoaderType
try:
from qumat_qdp.loader import QuantumDataLoader
except ImportError:
QuantumDataLoader: type[QuantumDataLoaderType] | None = None
def _loader_available() -> bool:
    """Tell whether the optional ``QuantumDataLoader`` import succeeded.

    The module-level import falls back to ``None`` on ``ImportError``; this
    helper simply checks for that sentinel so it can feed ``skipif`` markers.
    """
    loader_missing = QuantumDataLoader is None
    return not loader_missing
def _require_loader_cls() -> type[QuantumDataLoaderType]:
    """Hand back the ``QuantumDataLoader`` class, skipping the test if it is absent.

    Returns:
        The imported class, cast so type checkers see the non-optional type.

    ``pytest.skip`` is called (and does not return) when the import fallback
    left ``QuantumDataLoader`` set to ``None``.
    """
    candidate = QuantumDataLoader
    if candidate is not None:
        return cast("type[QuantumDataLoaderType]", candidate)
    # Import failed at module load: abort the calling test as skipped.
    pytest.skip("QuantumDataLoader not available")
@pytest.fixture
def loader_cls() -> type[QuantumDataLoaderType]:
    """Fixture providing the QuantumDataLoader class.

    Delegates to ``_require_loader_cls``, so a test requesting this fixture is
    skipped when the class could not be imported; the return annotation gives
    tests a non-optional type to work with.
    """
    resolved_cls = _require_loader_cls()
    return resolved_cls
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_mutual_exclusion_both_sources_raises(
    loader_cls: type[QuantumDataLoaderType],
):
    """Choosing the synthetic source and then a file source fails on iteration.

    The builder calls are made outside ``pytest.raises``; only consuming the
    loader is expected to raise ``ValueError``, and the message must name the
    conflict and both builder methods.
    """
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    conflicted = builder.source_synthetic().source_file("/tmp/any.parquet")

    with pytest.raises(ValueError) as raised:
        list(conflicted)

    text = str(raised.value)
    expected_fragments = (
        "Cannot set both synthetic and file sources",
        "source_synthetic",
        "source_file",
    )
    for fragment in expected_fragments:
        assert fragment in text
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_mutual_exclusion_exact_message(loader_cls: type[QuantumDataLoaderType]):
    """File source first, synthetic source second: iteration raises ValueError.

    Mirrors the both-sources test with the builder calls in the opposite
    order and checks that the conflict message is reported.
    """
    configured = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    both_sources = configured.source_file("/tmp/x.npy").source_synthetic()

    with pytest.raises(ValueError) as raised:
        list(both_sources)

    reported = str(raised.value)
    assert "Cannot set both synthetic and file sources" in reported
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_source_file_empty_path_raises(loader_cls: type[QuantumDataLoaderType]):
    """An empty string passed to source_file() is rejected with ValueError.

    The error is expected from the builder call itself, and its message
    (case-insensitively) must mention "path".
    """
    builder = loader_cls(device_id=0)
    builder = builder.qubits(4)
    builder = builder.batches(10, size=4)

    with pytest.raises(ValueError) as raised:
        builder.source_file("")

    lowered = str(raised.value).lower()
    assert "path" in lowered
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_synthetic_loader_batch_count(loader_cls: type[QuantumDataLoaderType]):
    """Iterating a synthetic loader produces as many batches as were requested.

    A ``RuntimeError`` whose text signals a missing platform/backend turns the
    test into a skip; any other ``RuntimeError`` is re-raised unchanged.
    """
    expected_batches = 5
    per_batch = 4
    builder = loader_cls(device_id=0).qubits(4)
    synthetic = builder.batches(expected_batches, size=per_batch).source_synthetic()

    try:
        produced = list(synthetic)
    except RuntimeError as err:
        detail = str(err)
        unavailable_markers = ("only available on Linux", "not available")
        if any(marker in detail for marker in unavailable_markers):
            pytest.skip("CUDA/Linux required for loader iteration")
        raise

    assert len(produced) == expected_batches
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_file_loader_unsupported_extension_raises(
    loader_cls: type[QuantumDataLoaderType],
):
    """source_file with an unsupported extension raises when the loader is iterated.

    Outcomes:
        * ``RuntimeError`` mentioning "not available": the test is skipped.
        * Any other ``RuntimeError``: its lower-cased message must mention
          "unsupported", "extension" or "supported".
        * ``ValueError``: the test is skipped.
        * No exception at all: the test fails.
    """
    loader = (
        loader_cls(device_id=0)
        .qubits(4)
        .batches(10, size=4)
        .source_file("/tmp/data.unsupported")
    )
    try:
        list(loader)
    except RuntimeError as e:
        msg = str(e).lower()
        if "not available" in msg:
            # pytest.skip() raises, so control never continues past this call.
            pytest.skip(
                "create_file_loader not available (e.g. extension built without loader)"
            )
        assert "unsupported" in msg or "extension" in msg or "supported" in msg
    except ValueError:
        pytest.skip("Loader may validate path before Rust")
    else:
        # Iteration finished without any error: the extension was not rejected.
        pytest.fail("Expected RuntimeError for unsupported file extension")
# --- Streaming (source_file(..., streaming=True)) tests ---
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_streaming_requires_parquet(loader_cls: type[QuantumDataLoaderType]):
    """Requesting streaming for a non-.parquet path is rejected with ValueError.

    The whole builder chain runs inside ``pytest.raises``; no iteration takes
    place. The message must mention "parquet" or "streaming" (any case).
    """
    with pytest.raises(ValueError) as raised:
        builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
        builder.source_file("/tmp/data.npy", streaming=True)

    lowered = str(raised.value).lower()
    assert any(keyword in lowered for keyword in ("parquet", "streaming"))
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_streaming_parquet_extension_ok(loader_cls: type[QuantumDataLoaderType]):
    """A .parquet path combined with streaming=True is accepted by the builder."""
    parquet_path = "/tmp/data.parquet"
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    streaming_loader = builder.source_file(parquet_path, streaming=True)

    # Only builder state is inspected here: iterating could raise RuntimeError
    # (no CUDA) or fail because the file does not exist.
    assert streaming_loader._streaming_requested is True
    assert streaming_loader._file_path == parquet_path
# --- NullHandling builder tests ---
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_fill_zero(loader_cls: type[QuantumDataLoaderType]):
    """The 'fill_zero' policy passed to null_handling() is stored on the loader."""
    policy = "fill_zero"
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    configured = builder.null_handling(policy)

    assert configured._null_handling == policy
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_reject(loader_cls: type[QuantumDataLoaderType]):
    """The 'reject' policy passed to null_handling() is stored on the loader."""
    policy = "reject"
    builder = loader_cls(device_id=0).qubits(4).batches(10, size=4)
    configured = builder.null_handling(policy)

    assert configured._null_handling == policy
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_invalid_raises(loader_cls: type[QuantumDataLoaderType]):
    """An unrecognised null-handling policy string is rejected with ValueError.

    The error text must name at least one of the accepted policies.
    """
    with pytest.raises(ValueError) as raised:
        fresh = loader_cls(device_id=0)
        fresh.null_handling("invalid_policy")

    text = str(raised.value)
    assert any(option in text for option in ("fill_zero", "reject"))
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_null_handling_default_is_none(loader_cls: type[QuantumDataLoaderType]):
    """A freshly constructed loader has no null-handling policy set.

    NOTE(review): the original note says the Rust side then uses FillZero;
    that fallback is not exercised by this test.
    """
    untouched = loader_cls(device_id=0)
    stored_policy = untouched._null_handling
    assert stored_policy is None
# --- Remote URL (source_file) builder tests ---
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.parametrize(
    ("path", "streaming"),
    [
        pytest.param("s3://my-bucket/data.parquet", False, id="parquet-no-stream"),
        pytest.param("s3://bucket/path/to/data.parquet", True, id="parquet-stream"),
        pytest.param("s3://bucket/data.npy", False, id="npy-no-stream"),
        pytest.param("gs://my-bucket/data.parquet", False, id="gcs-parquet-no-stream"),
        pytest.param("gs://bucket/path/to/data.parquet", True, id="gcs-parquet-stream"),
        pytest.param("gs://bucket/data.npy", False, id="gcs-npy-no-stream"),
    ],
)
def test_source_file_remote_url_accepted(path, streaming):
    """Valid s3:// and gs:// URLs pass builder-level validation in source_file().

    Only the state recorded by the builder is checked; nothing is iterated.
    """
    cls = _require_loader_cls()
    builder = cls(device_id=0).qubits(4).batches(10, size=4)
    remote_loader = builder.source_file(path, streaming=streaming)

    assert remote_loader._file_path == path
    assert remote_loader._file_requested is True
    assert remote_loader._streaming_requested is streaming
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.parametrize(
    "path",
    [
        pytest.param("s3://bucket/data.npy", id="s3-npy"),
        pytest.param("gs://bucket/data.npy", id="gcs-npy"),
    ],
)
def test_source_file_remote_streaming_non_parquet_raises(path):
    """Streaming from a remote URL that is not .parquet raises ValueError.

    The message must mention "parquet" or "streaming" (any case).
    """
    cls = _require_loader_cls()

    with pytest.raises(ValueError) as raised:
        builder = cls(device_id=0).qubits(4).batches(10, size=4)
        builder.source_file(path, streaming=True)

    lowered = str(raised.value).lower()
    assert any(keyword in lowered for keyword in ("parquet", "streaming"))
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.parametrize(
    "path",
    [
        pytest.param("s3://bucket/data.parquet?versionId=abc", id="s3-query"),
        pytest.param("s3://bucket/data.parquet#v1", id="s3-fragment"),
        pytest.param("gs://bucket/data.parquet?generation=123", id="gcs-query"),
        pytest.param("gs://bucket/data.parquet#v2", id="gcs-fragment"),
    ],
)
def test_source_file_remote_query_fragment_raises(path):
    """Remote URLs carrying a query string or fragment are rejected with ValueError.

    The message must mention "query", "fragment", or the expected
    "scheme://bucket/key" form (compared case-insensitively).
    """
    cls = _require_loader_cls()

    with pytest.raises(ValueError) as raised:
        builder = cls(device_id=0).qubits(4).batches(10, size=4)
        builder.source_file(path)

    lowered = str(raised.value).lower()
    accepted_hints = ("query", "fragment", "scheme://bucket/key")
    assert any(hint in lowered for hint in accepted_hints)