blob: ad2d6ebf7c35862b65f4d00a5aaebaf422bd2925 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
import jwt
import pytest
import time_machine
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
from airflow.utils import timezone
from airflow.utils.serve_logs import create_app
from tests_common.test_utils.config import conf_vars
if TYPE_CHECKING:
from flask.testing import FlaskClient
# Content written to sample.log by the sample_log fixture; served responses are compared against it.
LOG_DATA = "Airflow log data" * 20
@pytest.fixture
def client_without_config(tmp_path):
    """Yield a test client for an app created with ``base_log_folder`` set to ``tmp_path``.

    The configuration overrides stay active for as long as the fixture is in
    use, because the client is yielded from inside the ``conf_vars`` context.
    """
    overrides = {
        ("logging", "base_log_folder"): tmp_path.as_posix(),
        # NOTE(review): "jwt_leeyway" looks like a misspelling of "jwt_leeway" --
        # confirm whether this override has any effect before renaming it.
        ("api_auth", "jwt_leeyway"): "0",
    }
    with conf_vars(overrides):
        yield create_app().test_client()
@pytest.fixture
def client_with_config():
    """Yield a test client for an app created with an explicit ``logging_config_class``.

    ``logging_config_class`` is pointed at the dotted path of Airflow's
    ``DEFAULT_LOGGING_CONFIG``; the overrides remain active while the fixture
    is in use because the client is yielded from inside the ``conf_vars`` context.
    """
    overrides = {
        ("logging", "logging_config_class"): (
            "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG"
        ),
        # NOTE(review): "jwt_leeyway" looks like a misspelling of "jwt_leeway" --
        # confirm whether this override has any effect before renaming it.
        ("api_auth", "jwt_leeyway"): "0",
    }
    with conf_vars(overrides):
        yield create_app().test_client()
@pytest.fixture(params=["client_without_config", "client_with_config"])
def client(request):
    """Return the client fixture named by the current param.

    Parametrized over both client fixtures, so every test that requests
    ``client`` runs once per configuration.
    """
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
@pytest.fixture
def sample_log(request, tmp_path):
    """Write ``LOG_DATA`` to ``sample.log`` in the log folder of the active client and return its path.

    For ``client_without_config`` the file goes into ``tmp_path``; for
    ``client_with_config`` it goes into the ``base_log_folder`` of the ``task``
    handler in ``DEFAULT_LOGGING_CONFIG``.

    Raises:
        ValueError: if the active client matches neither known client fixture.
    """
    active_client = request.getfixturevalue("client")
    # NOTE(review): looking up "client_without_config" here instantiates that
    # fixture even when the other client is the active one -- confirm intended.
    if active_client == request.getfixturevalue("client_without_config"):
        log_root = tmp_path
    elif active_client == request.getfixturevalue("client_with_config"):
        task_handler = DEFAULT_LOGGING_CONFIG["handlers"]["task"]
        log_root = Path(task_handler["base_log_folder"])
    else:
        raise ValueError(f"Unknown client fixture: {active_client}")

    log_file = log_root / "sample.log"
    log_file.write_text(LOG_DATA)
    return log_file
@pytest.fixture
def jwt_generator(secret_key):
    """Return a ``JWTGenerator`` issuing tokens for the ``task-instance-logs`` audience.

    ``secret_key`` is a fixture that is not defined in this file.
    """
    # valid_for=5 is presumably a lifetime in seconds -- confirm against JWTGenerator.
    generator = JWTGenerator(secret_key=secret_key, valid_for=5, audience="task-instance-logs")
    return generator
@pytest.fixture
def different_audience(secret_key):
    """Return a ``JWTGenerator`` whose tokens carry the ``different-audience`` audience.

    Uses the same ``secret_key`` fixture as ``jwt_generator``; only the
    audience and ``valid_for`` value differ.
    """
    generator = JWTGenerator(secret_key=secret_key, valid_for=30, audience="different-audience")
    return generator
@pytest.mark.usefixtures("sample_log")
class TestServeLogs:
    """Authorization tests for ``GET /log/sample.log`` on the log-serving app.

    Each test runs once per parametrized ``client`` fixture, and the
    ``sample_log`` fixture has already written ``sample.log`` before the test
    body starts.

    NOTE(review): the pass/fail boundaries of the clock-skew tests imply that
    the server tolerates some clock skew; that tolerance is presumably
    configured outside this file -- confirm before changing any offsets.
    """

    @staticmethod
    def _get_sample_log(client, token=None):
        """GET ``/log/sample.log``, sending ``token`` as the Authorization header if given."""
        if token is None:
            return client.get("/log/sample.log")
        return client.get("/log/sample.log", headers={"Authorization": token})

    @staticmethod
    def _token_generated_at(jwt_generator, destination):
        """Generate a ``sample.log`` token while the clock is moved to ``destination``.

        The clock is restored before returning, so a request made with the
        returned token is evaluated at the real current time.
        """
        with time_machine.travel(destination):
            return jwt_generator.generate({"filename": "sample.log"})

    def test_forbidden_no_auth(self, client: FlaskClient):
        """A request without an Authorization header is rejected."""
        assert self._get_sample_log(client).status_code == 403

    def test_should_serve_file(self, client: FlaskClient, jwt_generator):
        """A token issued for ``sample.log`` returns the file content."""
        token = jwt_generator.generate({"filename": "sample.log"})
        response = self._get_sample_log(client, token)
        # Check the status first so an auth failure is not reported as a body mismatch.
        assert response.status_code == 200
        assert response.data.decode() == LOG_DATA

    def test_forbidden_different_logname(self, client: FlaskClient, jwt_generator):
        """A token issued for another filename cannot be used to read ``sample.log``."""
        token = jwt_generator.generate({"filename": "different.log"})
        assert self._get_sample_log(client, token).status_code == 403

    def test_forbidden_expired(self, client: FlaskClient, jwt_generator):
        """A token generated on 2010-01-14 is rejected."""
        token = self._token_generated_at(jwt_generator, "2010-01-14")
        assert self._get_sample_log(client, token).status_code == 403

    def test_forbidden_future(self, client: FlaskClient, jwt_generator):
        """A token generated one hour in the future is rejected."""
        token = self._token_generated_at(jwt_generator, timezone.utcnow() + timedelta(seconds=3600))
        assert self._get_sample_log(client, token).status_code == 403

    def test_ok_with_short_future_skew(self, client: FlaskClient, jwt_generator):
        """A token generated one second in the future is still accepted."""
        token = self._token_generated_at(jwt_generator, timezone.utcnow() + timedelta(seconds=1))
        assert self._get_sample_log(client, token).status_code == 200

    def test_ok_with_short_past_skew(self, client: FlaskClient, jwt_generator):
        """A token generated 31 seconds in the past is still accepted."""
        token = self._token_generated_at(jwt_generator, timezone.utcnow() - timedelta(seconds=31))
        assert self._get_sample_log(client, token).status_code == 200

    def test_forbidden_with_long_future_skew(self, client: FlaskClient, jwt_generator):
        """A token generated 40 seconds in the future is rejected."""
        token = self._token_generated_at(jwt_generator, timezone.utcnow() + timedelta(seconds=40))
        assert self._get_sample_log(client, token).status_code == 403

    def test_forbidden_with_long_past_skew(self, client: FlaskClient, jwt_generator):
        """A token generated 40 seconds in the past is rejected."""
        token = self._token_generated_at(jwt_generator, timezone.utcnow() - timedelta(seconds=40))
        assert self._get_sample_log(client, token).status_code == 403

    def test_wrong_audience(self, client: FlaskClient, different_audience):
        """A token carrying a different audience is rejected."""
        token = different_audience.generate({"filename": "sample.log"})
        assert self._get_sample_log(client, token).status_code == 403

    @pytest.mark.parametrize("claim_to_remove", ["iat", "exp", "nbf", "aud"])
    def test_missing_claims(self, claim_to_remove: str, client: FlaskClient, secret_key):
        """A hand-built token lacking any one of iat/exp/nbf/aud is rejected."""
        # Take a single timestamp so iat, nbf and exp are derived from the same instant.
        now = timezone.utcnow()
        claims = {
            "aud": "task-instance-logs",
            "iat": now,
            "nbf": now,
            "exp": now + timedelta(seconds=30),
        }
        del claims[claim_to_remove]
        claims["filename"] = "sample.log"
        token = jwt.encode(claims, secret_key, algorithm="HS512")
        assert self._get_sample_log(client, token).status_code == 403