tests: Added test for bearer authorization in DownloadableFileSource
diff --git a/tests/sources/tar.py b/tests/sources/tar.py
index c7abbc6..8712c77 100644
--- a/tests/sources/tar.py
+++ b/tests/sources/tar.py
@@ -29,7 +29,7 @@
from buildstream._testing import generate_project, generate_element
from buildstream._testing import cli # pylint: disable=unused-import
from buildstream._testing._utils.site import HAVE_LZIP
-from tests.testutils.file_server import create_file_server
+from tests.testutils.file_server import create_file_server, create_bearer_http_server
from . import list_dir_contents
DATA_DIR = os.path.join(
@@ -355,6 +355,74 @@
assert checkout_contents == original_contents
+@pytest.mark.datafiles(os.path.join(DATA_DIR, "fetch"))
+def test_use_netrc_bearer_auth(cli, datafiles, tmpdir):
+ file_server_files = os.path.join(str(tmpdir), "file_server")
+ fake_home = os.path.join(str(tmpdir), "fake_home")
+ os.makedirs(file_server_files, exist_ok=True)
+ os.makedirs(fake_home, exist_ok=True)
+ project = str(datafiles)
+ checkoutdir = os.path.join(str(tmpdir), "checkout")
+
+ os.environ["HOME"] = fake_home
+ with open(os.path.join(fake_home, ".netrc"), "wb") as f:
+ os.fchmod(f.fileno(), 0o700)
+ f.write(b"machine 127.0.0.1\n")
+ f.write(b"password 12345\n")
+
+ #
+ # Enable using mirrors for source tracking
+ #
+ cli.configure({"track": {"source": "mirrors"}})
+
+ #
+ # Create a file server which uses bearer authentication
+ #
+ with create_bearer_http_server() as server:
+ server.set_directory(file_server_files)
+ server.add_token("12345")
+
+ #
+ # Configure the project to load our source mirror plugin which
+ # reports the "auth-header-format" extra data
+ #
+ additional_config = {
+ "aliases": {"tmpdir": server.base_url()},
+ "mirrors": [
+ {
+ "name": "middle-earth",
+ "kind": "bearermirror",
+ "aliases": {
+ "tmpdir": [server.base_url()],
+ },
+ },
+ ],
+ "plugins": [
+ {"origin": "local", "path": "sourcemirrors", "source-mirrors": ["bearermirror"]},
+ ],
+ }
+ generate_project(project, config=additional_config)
+
+ src_tar = os.path.join(file_server_files, "a.tar.gz")
+ _assemble_tar(os.path.join(str(datafiles), "content"), "a", src_tar)
+
+ server.start()
+
+ result = cli.run(project=project, args=["source", "track", "target.bst"])
+ result.assert_success()
+ result = cli.run(project=project, args=["source", "fetch", "target.bst"])
+ result.assert_success()
+ result = cli.run(project=project, args=["build", "target.bst"])
+ result.assert_success()
+ result = cli.run(project=project, args=["artifact", "checkout", "target.bst", "--directory", checkoutdir])
+ result.assert_success()
+
+ original_dir = os.path.join(str(datafiles), "content", "a")
+ original_contents = list_dir_contents(original_dir)
+ checkout_contents = list_dir_contents(checkoutdir)
+ assert checkout_contents == original_contents
+
+
@pytest.mark.parametrize("server_type", ("FTP", "HTTP"))
@pytest.mark.datafiles(os.path.join(DATA_DIR, "fetch"))
def test_netrc_already_specified_user(cli, datafiles, server_type, tmpdir):
diff --git a/tests/sources/tar/fetch/sourcemirrors/bearermirror.py b/tests/sources/tar/fetch/sourcemirrors/bearermirror.py
new file mode 100644
index 0000000..54d9305
--- /dev/null
+++ b/tests/sources/tar/fetch/sourcemirrors/bearermirror.py
@@ -0,0 +1,29 @@
+from typing import Optional, Dict, Any
+
+from buildstream import SourceMirror, MappingNode
+
+
+class Sample(SourceMirror):
+ BST_MIN_VERSION = "2.0"
+
+ def translate_url(
+ self,
+ *,
+ project_name: str,
+ alias: str,
+ alias_url: str,
+ alias_substitute_url: Optional[str],
+ source_url: str,
+ extra_data: Optional[Dict[str, Any]],
+ ) -> str:
+
+ if extra_data is not None:
+ extra_data["auth-header-format"] = "Bearer {password}"
+
+ return alias_substitute_url + source_url
+
+
+# Plugin entry point
+def setup():
+
+ return Sample
diff --git a/tests/testutils/bearer_http_server.py b/tests/testutils/bearer_http_server.py
new file mode 100644
index 0000000..6760983
--- /dev/null
+++ b/tests/testutils/bearer_http_server.py
@@ -0,0 +1,116 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import multiprocessing
+import os
+import posixpath
+import html
+from http.server import SimpleHTTPRequestHandler, HTTPServer, HTTPStatus
+
+
+class Unauthorized(Exception):
+ pass
+
+
+class BearerRequestHandler(SimpleHTTPRequestHandler):
+ def get_root_dir(self):
+ authorization = self.headers.get("authorization")
+ if not authorization:
+ raise Unauthorized("unauthorized")
+
+ authorization = authorization.split()
+ if len(authorization) != 2 or authorization[0].lower() != "bearer":
+ raise Unauthorized("unauthorized")
+
+ token = authorization[1]
+ if token not in self.server.tokens:
+ raise Unauthorized("unauthorized")
+
+ return self.server.directory
+
+ def unauthorized(self):
+ shortmsg, longmsg = self.responses[HTTPStatus.UNAUTHORIZED]
+ self.send_response(HTTPStatus.UNAUTHORIZED, shortmsg)
+ self.send_header("Connection", "close")
+
+ content = self.error_message_format % {
+ "code": HTTPStatus.UNAUTHORIZED,
+ "message": html.escape(longmsg, quote=False),
+ "explain": html.escape(longmsg, quote=False),
+ }
+ body = content.encode("UTF-8", "replace")
+ self.send_header("Content-Type", self.error_content_type)
+ self.send_header("Content-Length", str(len(body)))
+ self.send_header("WWW-Authenticate", 'Bearer realm="{}"'.format(self.server.realm))
+ self.end_headers()
+ self.end_headers()
+
+ if self.command != "HEAD" and body:
+ self.wfile.write(body)
+
+ def do_GET(self):
+ try:
+ super().do_GET()
+ except Unauthorized:
+ self.unauthorized()
+
+ def do_HEAD(self):
+ try:
+ super().do_HEAD()
+ except Unauthorized:
+ self.unauthorized()
+
+ def translate_path(self, path):
+ path = path.split("?", 1)[0]
+ path = path.split("#", 1)[0]
+ path = posixpath.normpath(path)
+ assert posixpath.isabs(path)
+ path = posixpath.relpath(path, "/")
+ return os.path.join(self.get_root_dir(), path)
+
+
+class BearerHTTPServer(HTTPServer):
+ def __init__(self, *args, **kwargs):
+ self.tokens = set()
+ self.directory = None
+ self.realm = "Realm"
+ super().__init__(*args, **kwargs)
+
+
+class BearerHttpServer(multiprocessing.Process):
+ def __init__(self):
+ super().__init__()
+ self.server = BearerHTTPServer(("127.0.0.1", 0), BearerRequestHandler)
+ self.started = False
+
+ def start(self):
+ self.started = True
+ super().start()
+
+ def run(self):
+ self.server.serve_forever()
+
+ def stop(self):
+ if not self.started:
+ return
+ self.terminate()
+ self.join()
+
+ def set_directory(self, directory):
+ self.server.directory = directory
+
+ def add_token(self, token):
+ self.server.tokens.add(token)
+
+ def base_url(self):
+ return "http://127.0.0.1:{}".format(self.server.server_port)
diff --git a/tests/testutils/file_server.py b/tests/testutils/file_server.py
index ac1d6ec..1d43edb 100644
--- a/tests/testutils/file_server.py
+++ b/tests/testutils/file_server.py
@@ -15,6 +15,7 @@
from .ftp_server import SimpleFtpServer
from .http_server import SimpleHttpServer
+from .bearer_http_server import BearerHttpServer
@contextmanager
@@ -30,3 +31,19 @@
yield server
finally:
server.stop()
+
+
+#
+# We use a separate function here in order to avoid
+# confusing the linter (which thinks that anything
+# yielded by `create_file_server()` is a SimpleFtpServer).
+#
+# And no, type annotations with Union[...] does not fix this.
+#
+@contextmanager
+def create_bearer_http_server():
+ server = BearerHttpServer()
+ try:
+ yield server
+ finally:
+ server.stop()