blob: 78c75cebddf6ed6a013e3d9aa96fda9f00828668 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import uuid
from stat import S_ISDIR, S_ISREG
from typing import Any, ClassVar
from unittest import mock
import pytest
from fsspec.implementations.local import LocalFileSystem
from fsspec.implementations.memory import MemoryFileSystem
from airflow.sdk import Asset, ObjectStoragePath
from airflow.sdk.io import attach
from airflow.sdk.io.store import _STORE_CACHE, ObjectStore
from airflow.sdk.module_loading import qualname
def test_init():
    """Check bucket, key, protocol and path of an S3 path and of children built with ``/``."""

    def check_components(candidate, expected_key):
        # Every path in this test lives in the same bucket under the same protocol.
        assert candidate.bucket == "bucket"
        assert candidate.key == expected_key
        assert candidate.protocol == "s3"
        assert candidate.path == f"bucket/{expected_key}"

    root = ObjectStoragePath("s3://bucket/key/part1/part2")
    check_components(root, "key/part1/part2")

    # Re-wrapping the result of ``/`` in ObjectStoragePath must keep all components.
    child = ObjectStoragePath(root / "part3")
    check_components(child, "key/part1/part2/part3")

    # A purely numeric segment is joined like any other segment.
    grandchild = ObjectStoragePath(child / "2023")
    check_components(grandchild, "key/part1/part2/part3/2023")
@pytest.mark.parametrize("input_str", ("file:///tmp/foo", "s3://conn_id@bucket/test.txt"))
def test_str(input_str):
    """``str()`` of a path must give back exactly the URI it was constructed from."""
    rendered = str(ObjectStoragePath(input_str))
    assert rendered == input_str
def test_cwd():
    """``ObjectStoragePath.cwd()`` must return a truthy value."""
    current_dir = ObjectStoragePath.cwd()
    assert current_dir
def test_home():
    """``ObjectStoragePath.home()`` must return a truthy value."""
    home_dir = ObjectStoragePath.home()
    assert home_dir
def test_lazy_load():
    """Check that ``_fs_cached`` only exists after the first access to ``fs``."""
    o = ObjectStoragePath("file:///tmp/foo")
    try:
        # Before ``fs`` is touched the cached attribute must not exist yet.
        with pytest.raises(AttributeError):
            assert o._fs_cached

        assert o.fs is not None
        assert o._fs_cached
    finally:
        # Clear the cache to avoid side effects in other tests below. Doing it in
        # ``finally`` keeps a failing assertion above from leaking cache entries.
        _STORE_CACHE.clear()
class _FakeRemoteFileSystem(MemoryFileSystem):
    """Memory-backed filesystem that answers to several fake "remote" protocols in tests."""

    protocol = ("s3", "fakefs", "ffs", "ffs2")
    root_marker = ""
    # ``store`` and ``pseudo_dirs`` are class attributes, so all instances share them.
    store: ClassVar[dict[str, Any]] = {}
    pseudo_dirs = [""]

    def __init__(self, *args, **kwargs):
        """Remember the optional ``conn_id`` keyword and pass everything else to the parent."""
        conn_id = kwargs.pop("conn_id", None)
        self.conn_id = conn_id
        super().__init__(*args, **kwargs)

    @classmethod
    def _strip_protocol(cls, path):
        """
        Return *path* without its protocol prefix.

        A leading ``<protocol>://`` for one of this class's protocols is cut off and the
        remainder is returned untouched. Any other path containing ``::`` or ``://`` only
        loses trailing slashes; all remaining paths lose both leading and trailing slashes.
        """
        for name in cls.protocol:
            prefix = f"{name}://"
            if path.startswith(prefix):
                return path[len(prefix) :]

        looks_like_url = "::" in path or "://" in path
        if looks_like_url:
            return path.rstrip("/")

        return path.strip("/")
class TestAttach:
    """Tests for ``attach`` and for paths that use an attached fake filesystem."""

    FAKE = "ffs:///fake"
    MNT = "ffs:///mnt/warehouse"
    FOO = "ffs:///mnt/warehouse/foo"
    BAR = FOO

    @pytest.fixture(autouse=True)
    def restore_cache(self):
        """Snapshot the store cache before each test and put the snapshot back afterwards."""
        snapshot = _STORE_CACHE.copy()
        yield
        _STORE_CACHE.clear()
        _STORE_CACHE.update(snapshot)

    @pytest.fixture
    def fake_files(self):
        """Create ``FOO`` in the fake filesystem, then reset its class-level state."""
        fake_fs = _FakeRemoteFileSystem()
        fake_fs.touch(self.FOO)
        try:
            yield
        finally:
            _FakeRemoteFileSystem.store.clear()
            _FakeRemoteFileSystem.pseudo_dirs[:] = [""]

    def test_alias(self):
        """Attaching with an alias must register the store under that alias."""
        local_store = attach("file", alias="local")

        assert isinstance(local_store.fs, LocalFileSystem)
        assert {"local": local_store} == _STORE_CACHE

    def test_objectstoragepath_init_conn_id_in_uri(self):
        """``stat()`` must extend the filesystem info with ``conn_id`` and ``protocol``."""
        attach(protocol="fake", conn_id="fake", fs=_FakeRemoteFileSystem(conn_id="fake"))
        remote = ObjectStoragePath("fake://fake@bucket/path")
        remote.touch()

        raw_info = remote.fs.info(remote.path)
        expected_stat = {**raw_info, "conn_id": "fake", "protocol": "fake"}
        assert remote.stat() == expected_stat

    @pytest.mark.parametrize(
        "fn, args, fn2, path, expected_args, expected_kwargs",
        [
            ("checksum", {}, "checksum", FOO, _FakeRemoteFileSystem._strip_protocol(BAR), {}),
            ("size", {}, "size", FOO, _FakeRemoteFileSystem._strip_protocol(BAR), {}),
            (
                "sign",
                {"expiration": 200, "extra": "xtra"},
                "sign",
                FOO,
                _FakeRemoteFileSystem._strip_protocol(BAR),
                {"expiration": 200, "extra": "xtra"},
            ),
            ("ukey", {}, "ukey", FOO, _FakeRemoteFileSystem._strip_protocol(BAR), {}),
            (
                "read_block",
                {"offset": 0, "length": 1},
                "read_block",
                FOO,
                _FakeRemoteFileSystem._strip_protocol(BAR),
                {"delimiter": None, "length": 1, "offset": 0},
            ),
        ],
    )
    def test_standard_extended_api(self, fake_files, fn, args, fn2, path, expected_args, expected_kwargs):
        """Calling ``fn`` on the path must call ``fn2`` on the filesystem with the stripped path."""
        backing_fs = _FakeRemoteFileSystem()
        attach(protocol="ffs", conn_id="fake", fs=backing_fs)

        with mock.patch.object(backing_fs, fn2) as patched_method:
            storage_path = ObjectStoragePath(path, conn_id="fake")
            path_method = getattr(storage_path, fn)
            path_method(**args)
            patched_method.assert_called_once_with(expected_args, **expected_kwargs)
class TestRemotePath:
    """Tests for remote-style URIs, with the filesystem factory replaced by the fake one."""

    @pytest.fixture(autouse=True)
    def fake_fs(self, monkeypatch):
        """Patch ``ObjectStoragePath._fs_factory`` to always build a ``_FakeRemoteFileSystem``."""

        def build_fake_fs(*a, **k):
            return _FakeRemoteFileSystem()

        monkeypatch.setattr(ObjectStoragePath, "_fs_factory", build_fake_fs)

    def test_bucket_key_protocol(self):
        """Bucket, container, key and protocol must be parsed out of the URI."""
        bucket = "bkt"
        key = "yek"
        protocol = "s3"
        uri = f"{protocol}://{bucket}/{key}"

        parsed = ObjectStoragePath(uri)

        assert parsed.bucket == bucket
        # ``container`` is expected to report the same value as ``bucket``.
        assert parsed.container == bucket
        assert parsed.key == key
        assert parsed.protocol == protocol
class TestLocalPath:
    """Exercise ``ObjectStoragePath`` against ``file://`` URIs, plus serialization checks."""

    @staticmethod
    def _new_empty_file(directory):
        """Create an empty, uniquely named file in *directory* and return its POSIX path."""
        created = directory.joinpath(str(uuid.uuid4()))
        created.touch()
        return created.as_posix()

    @pytest.fixture
    def target(self, tmp_path):
        """POSIX path of a fresh empty file inside ``tmp_path``."""
        return self._new_empty_file(tmp_path)

    @pytest.fixture
    def another(self, tmp_path):
        """POSIX path of a second fresh empty file inside ``tmp_path``."""
        return self._new_empty_file(tmp_path)

    def test_ls(self, tmp_path, target):
        """``iterdir`` must list the single file; recursive ``rmdir`` must remove it."""
        d = ObjectStoragePath(f"file://{tmp_path.as_posix()}")
        o = ObjectStoragePath(f"file://{target}")

        data = list(d.iterdir())
        assert len(data) == 1
        assert data[0] == o

        d.rmdir(recursive=True)
        assert not o.exists()

    def test_read_write(self, target):
        """Bytes written through ``open("wb")`` must be read back through ``open("rb")``."""
        o = ObjectStoragePath(f"file://{target}")
        with o.open("wb") as f:
            f.write(b"foo")

        # Read inside a ``with`` block so the handle is closed before the file is unlinked.
        with o.open("rb") as f:
            assert f.read() == b"foo"

        o.unlink()

    def test_stat(self, target):
        """``stat`` must report size 0 and a regular file; the parent must be a directory."""
        o = ObjectStoragePath(f"file://{target}")

        assert o.stat().st_size == 0
        assert S_ISREG(o.stat().st_mode)
        assert S_ISDIR(o.parent.stat().st_mode)

    def test_replace(self, target, another):
        """``replace`` must move the content onto the destination and remove the source."""
        o = ObjectStoragePath(f"file://{target}")
        i = ObjectStoragePath(f"file://{another}")
        assert i.size() == 0

        txt = "foo"
        o.write_text(txt)
        e = o.replace(i)

        assert o.exists() is False
        assert i == e
        assert e.size() == len(txt)

    def test_hash(self, target, another):
        """Paths built from the same URI must collapse to a single entry in a set."""
        file_uri_1 = f"file://{target}"
        file_uri_2 = f"file://{another}"

        s = set()
        for _ in range(10):
            s.add(ObjectStoragePath(file_uri_1))
            s.add(ObjectStoragePath(file_uri_2))

        assert len(s) == 2

    def test_is_relative_to(self, tmp_path, target):
        """A file is relative to its containing directory but not to an unrelated path."""
        o1 = ObjectStoragePath(f"file://{target}")
        o2 = ObjectStoragePath(f"file://{tmp_path.as_posix()}")
        o3 = ObjectStoragePath(f"file:///{uuid.uuid4()}")

        assert o1.is_relative_to(o2)
        assert not o1.is_relative_to(o3)

    def test_relative_to(self, tmp_path, target):
        """``relative_to`` the containing directory compares equal to the file; unrelated paths raise."""
        o1 = ObjectStoragePath(f"file://{target}")
        o2 = ObjectStoragePath(f"file://{tmp_path.as_posix()}")
        o3 = ObjectStoragePath(f"file:///{uuid.uuid4()}")

        assert o1.relative_to(o2) == o1
        with pytest.raises(ValueError, match="is not in the subpath of"):
            o1.relative_to(o3)

    def test_asset(self):
        """A path built from an ``Asset`` must take its protocol and path from the asset URI."""
        p = "s3"
        f = "bucket/object"
        i = Asset(uri=f"{p}://{f}", name="test-asset", extra={"foo": "bar"})

        o = ObjectStoragePath(i)

        assert o.protocol == p
        assert o.path == f

    def test_move_local(self, hook_lineage_collector, tmp_path, target):
        """``move`` must relocate the file and record one lineage input and one output."""
        o1 = ObjectStoragePath(f"file://{target}")
        # Use the POSIX form of tmp_path, as every other URI built in this class does.
        o2 = ObjectStoragePath(f"file://{tmp_path.as_posix()}/{uuid.uuid4()}")
        assert o1.exists()
        assert not o2.exists()

        o1.move(o2)

        assert o2.exists()
        assert not o1.exists()

        collected_assets = hook_lineage_collector.collected_assets
        assert len(collected_assets.inputs) == 1
        assert len(collected_assets.outputs) == 1
        assert collected_assets.inputs[0].asset.uri == str(o1)
        assert collected_assets.outputs[0].asset.uri == str(o2)

    def test_serde_objectstoragepath(self):
        """Paths must round-trip through serialize/deserialize, with kwargs and with a store."""
        path = "file:///bucket/key/part1/part2"

        # Plain path.
        o = ObjectStoragePath(path)
        s = o.serialize()
        assert s["path"] == path
        d = ObjectStoragePath.deserialize(s, 1)
        assert o == d

        # Extra keyword arguments must be carried in the serialized kwargs.
        o = ObjectStoragePath(path, my_setting="foo")
        s = o.serialize()
        assert "my_setting" in s["kwargs"]
        d = ObjectStoragePath.deserialize(s, 1)
        assert o == d

        # An explicitly attached store must be carried in the serialized kwargs.
        store = attach("file", conn_id="mock")
        o = ObjectStoragePath(path, store=store)
        s = o.serialize()
        assert s["kwargs"]["store"] == store
        d = ObjectStoragePath.deserialize(s, 1)
        assert o == d

    def test_serde_store(self):
        """Stores must round-trip through serialize/deserialize, with and without a conn_id."""
        # Store attached by protocol with a connection id.
        store = attach("file", conn_id="mock")
        s = store.serialize()
        d = ObjectStore.deserialize(s, 1)
        assert s["protocol"] == "file"
        assert s["conn_id"] == "mock"
        assert s["filesystem"] == qualname(LocalFileSystem)
        assert store == d

        # Store attached with an explicit filesystem instance and no connection id.
        store = attach("localfs", fs=LocalFileSystem())
        s = store.serialize()
        d = ObjectStore.deserialize(s, 1)
        assert s["protocol"] == "localfs"
        assert s["conn_id"] is None
        assert s["filesystem"] == qualname(LocalFileSystem)
        assert store == d
class TestBackwardsCompatibility:
    """Checks for the ``get_fs`` entry point imported from ``airflow.io``."""

    @pytest.fixture(autouse=True)
    def reset(self):
        """Clear the ``_register_filesystems`` cache before and after each test."""
        from airflow.sdk.io.fs import _register_filesystems

        _register_filesystems.cache_clear()
        yield
        _register_filesystems.cache_clear()

    def test_backwards_compat(self, monkeypatch):
        """
        Check ``get_fs`` with a factory that takes no ``storage_options``.

        Such a factory must work when no storage options are passed, and ``get_fs`` must
        raise ``AttributeError`` when storage options are passed to it.
        """
        from airflow.io import _BUILTIN_SCHEME_TO_FS, get_fs

        def get_fs_no_storage_options(_: str):
            return LocalFileSystem()

        # Use monkeypatch so the original "file" entry is put back on teardown instead
        # of leaving the stub factory in the shared mapping for later tests.
        monkeypatch.setitem(_BUILTIN_SCHEME_TO_FS, "file", get_fs_no_storage_options)

        assert get_fs("file")
        with pytest.raises(AttributeError):
            get_fs("file", storage_options={"foo": "bar"})