blob: 1ccdfbfb7980557d7f1dacde22706676335283f3 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import sys
import uuid
from stat import S_ISDIR, S_ISREG
from tempfile import NamedTemporaryFile
from typing import Any, ClassVar
from unittest import mock
import pytest
from fsspec.implementations.local import LocalFileSystem
from fsspec.implementations.memory import MemoryFileSystem
from fsspec.registry import _registry as _fsspec_registry, register_implementation
from airflow.datasets import Dataset
from airflow.io import _register_filesystems, get_fs
from airflow.io.path import ObjectStoragePath
from airflow.io.store import _STORE_CACHE, ObjectStore, attach
from airflow.utils.module_loading import qualname
# URL constants shared by the tests below. FOO is a file under MNT, and BAR is
# deliberately just another name for the same URL.
FAKE = "file:///fake"
MNT = "file:///mnt/warehouse"
FOO = f"{MNT}/foo"
BAR = FOO
class FakeLocalFileSystem(MemoryFileSystem):
    """In-memory filesystem registered under the ``file`` and ``local`` protocols.

    Defines its own class-level ``store`` and ``pseudo_dirs``, so every instance of
    this fake shares one in-memory tree (tests reset it explicitly). An optional
    ``conn_id`` keyword is recorded on the instance and is not forwarded to the
    base class.
    """

    protocol = ("file", "local")
    root_marker = "/"
    # Shared by all instances of this fake; the test fixture clears it afterwards.
    store: ClassVar[dict[str, Any]] = {}
    pseudo_dirs = [""]

    def __init__(self, *args, **kwargs):
        # Take conn_id out of kwargs first so the base initializer never receives it.
        connection_id = kwargs.pop("conn_id", None)
        self.conn_id = connection_id
        super().__init__(*args, **kwargs)

    @classmethod
    def _strip_protocol(cls, path):
        """Strip this filesystem's ``<protocol>://`` prefix from *path*.

        If *path* starts with one of ``cls.protocol`` followed by ``://``, everything
        after that prefix is returned unchanged. Chained (``::``) or other URLs only
        have trailing slashes removed. Bare paths have leading and trailing slashes
        removed.
        """
        for scheme in cls.protocol:
            prefix = f"{scheme}://"
            if path.startswith(prefix):
                return path[len(prefix) :]
        if "::" in path or "://" in path:
            return path.rstrip("/")
        return path.strip("/")
class FakeRemoteFileSystem(MemoryFileSystem):
    """In-memory filesystem posing as a remote store.

    It answers to the ``s3``, ``fakefs``, ``ffs`` and ``ffs2`` protocols and has an
    empty root marker. It keeps its own class-level ``store`` and ``pseudo_dirs``
    and records an optional ``conn_id`` keyword without passing it to the base class.
    """

    protocol = ("s3", "fakefs", "ffs", "ffs2")
    root_marker = ""
    # Class-level storage shared by every instance of this fake.
    store: ClassVar[dict[str, Any]] = {}
    pseudo_dirs = [""]

    def __init__(self, *args, **kwargs):
        # conn_id is bookkeeping for the tests only; hide it from the base initializer.
        self.conn_id = kwargs.pop("conn_id", None)
        super().__init__(*args, **kwargs)

    @classmethod
    def _strip_protocol(cls, path):
        """Remove a leading ``<protocol>://`` (for any of ``cls.protocol``) from *path*.

        Protocols are tried in declaration order and the first match wins. Without a
        match, chained or foreign URLs only lose trailing slashes, and bare paths
        lose slashes at both ends.
        """
        matched_prefix = next(
            (f"{scheme}://" for scheme in cls.protocol if path.startswith(f"{scheme}://")),
            None,
        )
        if matched_prefix is not None:
            return path[len(matched_prefix) :]
        if "::" in path or "://" in path:
            return path.rstrip("/")
        return path.strip("/")
def get_fs_no_storage_options(_: str):
    """Return a fresh ``LocalFileSystem``; the single positional argument is ignored.

    The signature deliberately has no ``storage_options`` parameter.
    ``TestFs.test_backwards_compat`` installs this callable as the ``file`` scheme
    handler to check how ``get_fs`` treats such older-style factories.
    """
    local_fs = LocalFileSystem()
    return local_fs
class TestFs:
    """Tests for ``ObjectStoragePath`` and the ``ObjectStore`` registry of ``airflow.io``.

    ``setup_class`` registers ``FakeRemoteFileSystem`` with fsspec for every protocol
    it declares. Stores attached by an individual test are dropped after that test,
    and both global registries are restored once the whole class has run.
    """

    @classmethod
    def setup_class(cls):
        # Snapshot the module-level registries before modifying them so teardown can
        # put them back exactly as they were.
        cls._store_cache = _STORE_CACHE.copy()
        cls._fsspec_registry = _fsspec_registry.copy()
        for protocol in FakeRemoteFileSystem.protocol:
            register_implementation(protocol, FakeRemoteFileSystem, clobber=True)

    def teardown_method(self):
        # Forget stores attached by the test that just ran. The fsspec registry is
        # intentionally not reset here: doing so per test would also undo the fake
        # registrations that setup_class makes only once for the whole class.
        _STORE_CACHE.clear()
        _STORE_CACHE.update(self._store_cache)

    @classmethod
    def teardown_class(cls):
        # Undo everything setup_class (and the tests) changed in the global registries.
        _STORE_CACHE.clear()
        _STORE_CACHE.update(cls._store_cache)
        _fsspec_registry.clear()
        _fsspec_registry.update(cls._fsspec_registry)

    def test_alias(self):
        store = attach("file", alias="local")
        assert isinstance(store.fs, LocalFileSystem)
        assert "local" in _STORE_CACHE

    def test_init_objectstoragepath(self):
        attach("s3", fs=FakeRemoteFileSystem())

        path = ObjectStoragePath("s3://bucket/key/part1/part2")
        assert path.bucket == "bucket"
        assert path.key == "key/part1/part2"
        assert path.protocol == "s3"
        assert path.path == "bucket/key/part1/part2"

        path2 = ObjectStoragePath(path / "part3")
        assert path2.bucket == "bucket"
        assert path2.key == "key/part1/part2/part3"
        assert path2.protocol == "s3"
        assert path2.path == "bucket/key/part1/part2/part3"

        path3 = ObjectStoragePath(path2 / "2023")
        assert path3.bucket == "bucket"
        assert path3.key == "key/part1/part2/part3/2023"
        assert path3.protocol == "s3"
        assert path3.path == "bucket/key/part1/part2/part3/2023"

    def test_read_write(self):
        o = ObjectStoragePath(f"file:///tmp/{uuid.uuid4()}")

        with o.open("wb") as f:
            f.write(b"foo")

        # Use a context manager so the read handle is closed deterministically.
        with o.open("rb") as f:
            assert f.read() == b"foo"

        o.unlink()

    def test_ls(self):
        dirname = str(uuid.uuid4())
        filename = str(uuid.uuid4())

        d = ObjectStoragePath(f"file:///tmp/{dirname}")
        d.mkdir(parents=True)
        o = d / filename
        o.touch()

        data = list(d.iterdir())
        assert len(data) == 1
        assert data[0] == o

        d.rmdir(recursive=True)
        assert not o.exists()

    def test_objectstoragepath_init_conn_id_in_uri(self):
        attach(protocol="fake", conn_id="fake", fs=FakeRemoteFileSystem(conn_id="fake"))
        p = ObjectStoragePath("fake://fake@bucket/path")
        p.touch()
        fsspec_info = p.fs.info(p.path)
        assert p.stat() == {**fsspec_info, "conn_id": "fake", "protocol": "fake"}

    @pytest.fixture
    def fake_local_files(self):
        """Create FOO in FakeLocalFileSystem's shared store for one test, then wipe the store."""
        obj = FakeLocalFileSystem()
        obj.touch(FOO)
        try:
            yield
        finally:
            FakeLocalFileSystem.store.clear()
            FakeLocalFileSystem.pseudo_dirs[:] = [""]

    @pytest.mark.parametrize(
        "fn, args, fn2, path, expected_args, expected_kwargs",
        [
            ("checksum", {}, "checksum", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}),
            ("size", {}, "size", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}),
            (
                "sign",
                {"expiration": 200, "extra": "xtra"},
                "sign",
                FOO,
                FakeLocalFileSystem._strip_protocol(BAR),
                {"expiration": 200, "extra": "xtra"},
            ),
            ("ukey", {}, "ukey", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}),
            (
                "read_block",
                {"offset": 0, "length": 1},
                "read_block",
                FOO,
                FakeLocalFileSystem._strip_protocol(BAR),
                {"delimiter": None, "length": 1, "offset": 0},
            ),
        ],
    )
    def test_standard_extended_api(
        self, fake_local_files, fn, args, fn2, path, expected_args, expected_kwargs
    ):
        # Each ObjectStoragePath method should delegate to the matching fs method
        # with the protocol-stripped path plus the expected keyword arguments.
        fs = FakeLocalFileSystem()
        with mock.patch.object(fs, fn2) as method:
            attach(protocol="file", conn_id="fake", fs=fs)
            o = ObjectStoragePath(path, conn_id="fake")

            getattr(o, fn)(**args)
            method.assert_called_once_with(expected_args, **expected_kwargs)

    def test_stat(self):
        with NamedTemporaryFile() as f:
            o = ObjectStoragePath(f"file://{f.name}")
            assert o.stat().st_size == 0
            assert S_ISREG(o.stat().st_mode)
            assert S_ISDIR(o.parent.stat().st_mode)

    def test_bucket_key_protocol(self):
        attach(protocol="s3", fs=FakeRemoteFileSystem())

        bucket = "bkt"
        key = "yek"
        protocol = "s3"

        o = ObjectStoragePath(f"{protocol}://{bucket}/{key}")
        assert o.bucket == bucket
        assert o.container == bucket
        assert o.key == key
        assert o.protocol == protocol

    def test_cwd_home(self):
        assert ObjectStoragePath.cwd()
        assert ObjectStoragePath.home()

    def test_replace(self):
        o = ObjectStoragePath(f"file:///tmp/{uuid.uuid4()}")
        i = ObjectStoragePath(f"file:///tmp/{uuid.uuid4()}")

        o.touch()
        i.touch()

        assert i.size() == 0

        txt = "foo"
        o.write_text(txt)
        e = o.replace(i)
        assert o.exists() is False
        assert i == e
        assert e.size() == len(txt)

        e.unlink()

    @pytest.mark.skipif(sys.version_info < (3, 9), reason="`is_relative_to` new in version 3.9")
    def test_is_relative_to(self):
        uuid_dir = f"/tmp/{uuid.uuid4()}"
        o1 = ObjectStoragePath(f"file://{uuid_dir}/aaa")
        o2 = ObjectStoragePath(f"file://{uuid_dir}")
        o3 = ObjectStoragePath(f"file://{uuid.uuid4()}")
        assert o1.is_relative_to(o2)
        assert not o1.is_relative_to(o3)

    def test_relative_to(self):
        uuid_dir = f"/tmp/{uuid.uuid4()}"
        o1 = ObjectStoragePath(f"file://{uuid_dir}/aaa")
        o2 = ObjectStoragePath(f"file://{uuid_dir}")
        o3 = ObjectStoragePath(f"file://{uuid.uuid4()}")

        _ = o1.relative_to(o2)  # Should not raise any error

        with pytest.raises(ValueError):
            o1.relative_to(o3)

    def test_move_local(self):
        _from = ObjectStoragePath(f"file:///tmp/{uuid.uuid4()}")
        _to = ObjectStoragePath(f"file:///tmp/{uuid.uuid4()}")

        _from.touch()
        _from.move(_to)
        assert _to.exists()
        assert not _from.exists()

        _to.unlink()

    def test_move_remote(self):
        attach("fakefs", fs=FakeRemoteFileSystem())

        _from = ObjectStoragePath(f"file:///tmp/{uuid.uuid4()}")
        _to = ObjectStoragePath(f"fakefs:///tmp/{uuid.uuid4()}")

        _from.touch()
        _from.move(_to)
        assert not _from.exists()
        assert _to.exists()

        _to.unlink()

    def test_copy_remote_remote(self):
        # skip_instance_cache gives each protocol its own filesystem instance.
        attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
        attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))

        dir_src = f"bucket1/{uuid.uuid4()}"
        dir_dst = f"bucket2/{uuid.uuid4()}"
        key = "foo/bar/baz.txt"

        _from = ObjectStoragePath(f"ffs://{dir_src}")
        _from_file = _from / key
        _from_file.touch()
        assert _from.bucket == "bucket1"
        assert _from_file.exists()

        _to = ObjectStoragePath(f"ffs2://{dir_dst}")
        _from.copy(_to)

        assert _to.bucket == "bucket2"
        assert _to.exists()
        assert _to.is_dir()
        assert (_to / _from.key / key).exists()
        assert (_to / _from.key / key).is_file()

        _from.rmdir(recursive=True)
        _to.rmdir(recursive=True)

    def test_serde_objectstoragepath(self):
        path = "file:///bucket/key/part1/part2"
        o = ObjectStoragePath(path)

        s = o.serialize()
        assert s["path"] == path
        d = ObjectStoragePath.deserialize(s, 1)
        assert o == d

        o = ObjectStoragePath(path, my_setting="foo")
        s = o.serialize()
        assert "my_setting" in s["kwargs"]
        d = ObjectStoragePath.deserialize(s, 1)
        assert o == d

        store = attach("filex", conn_id="mock")
        o = ObjectStoragePath(path, store=store)
        s = o.serialize()
        assert s["kwargs"]["store"] == store

        d = ObjectStoragePath.deserialize(s, 1)
        assert o == d

    def test_serde_store(self):
        store = attach("file", conn_id="mock")
        s = store.serialize()
        d = ObjectStore.deserialize(s, 1)

        assert s["protocol"] == "file"
        assert s["conn_id"] == "mock"
        assert s["filesystem"] == qualname(LocalFileSystem)
        assert store == d

        store = attach("localfs", fs=LocalFileSystem())
        s = store.serialize()
        d = ObjectStore.deserialize(s, 1)

        assert s["protocol"] == "localfs"
        assert s["conn_id"] is None
        assert s["filesystem"] == qualname(LocalFileSystem)
        assert store == d

    def test_backwards_compat(self):
        """A factory without ``storage_options`` works, but passing options to it raises."""
        from airflow.io import _BUILTIN_SCHEME_TO_FS as SCHEMES

        _register_filesystems.cache_clear()
        try:
            # patch.dict restores the module-global scheme mapping on exit, so the
            # replacement "file" handler cannot leak into other tests.
            with mock.patch.dict(SCHEMES, {"file": get_fs_no_storage_options}):
                assert get_fs("file")

                with pytest.raises(AttributeError):
                    get_fs("file", storage_options={"foo": "bar"})
        finally:
            # Drop the registry that was computed from the patched mapping.
            _register_filesystems.cache_clear()

    def test_dataset(self):
        attach("s3", fs=FakeRemoteFileSystem())

        p = "s3"
        f = "/tmp/foo"
        i = Dataset(uri=f"{p}://{f}", extra={"foo": "bar"})
        o = ObjectStoragePath(i)
        assert o.protocol == p
        assert o.path == f

    def test_hash(self):
        file_uri_1 = f"file:///tmp/{uuid.uuid4()}"
        file_uri_2 = f"file:///tmp/{uuid.uuid4()}"
        s = set()
        for _ in range(10):
            s.add(ObjectStoragePath(file_uri_1))
            s.add(ObjectStoragePath(file_uri_2))
        assert len(s) == 2

    def test_lazy_load(self):
        o = ObjectStoragePath("file:///tmp/foo")
        # The filesystem must not be resolved until .fs is first accessed.
        with pytest.raises(AttributeError):
            assert o._fs_cached

        assert o.fs is not None
        assert o._fs_cached

    @pytest.mark.parametrize("input_str", ("file:///tmp/foo", "s3://conn_id@bucket/test.txt"))
    def test_str(self, input_str):
        o = ObjectStoragePath(input_str)
        assert str(o) == input_str