blob: 7b3cd486f0e2cecba2cc02552a696a82525e5f2d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING
import jwt
import pytest
from freezegun import freeze_time
from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.serve_logs import create_app
from tests.test_utils.config import conf_vars
if TYPE_CHECKING:
from flask.testing import FlaskClient
LOG_DATA = "Airflow log data" * 20
@pytest.fixture
def client(tmpdir):
with conf_vars({('logging', 'base_log_folder'): str(tmpdir)}):
app = create_app()
yield app.test_client()
@pytest.fixture
def sample_log(tmpdir):
f = tmpdir / 'sample.log'
f.write(LOG_DATA.encode())
return f
@pytest.fixture
def signer(secret_key):
return JWTSigner(
secret_key=secret_key,
expiration_time_in_seconds=30,
audience="task-instance-logs",
)
@pytest.fixture
def different_audience(secret_key):
return JWTSigner(
secret_key=secret_key,
expiration_time_in_seconds=30,
audience="different-audience",
)
@pytest.mark.usefixtures('sample_log')
class TestServeLogs:
def test_forbidden_no_auth(self, client: FlaskClient):
assert 403 == client.get('/log/sample.log').status_code
def test_should_serve_file(self, client: FlaskClient, signer):
response = client.get(
'/log/sample.log',
headers={
'Authorization': signer.generate_signed_token({"filename": 'sample.log'}),
},
)
assert response.data.decode() == LOG_DATA
assert response.status_code == 200
def test_forbidden_different_logname(self, client: FlaskClient, signer):
response = client.get(
'/log/sample.log',
headers={
'Authorization': signer.generate_signed_token({"filename": 'different.log'}),
},
)
assert response.status_code == 403
def test_forbidden_expired(self, client: FlaskClient, signer):
with freeze_time("2010-01-14"):
token = signer.generate_signed_token({"filename": 'sample.log'})
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 403
)
def test_forbidden_future(self, client: FlaskClient, signer):
with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=3600)):
token = signer.generate_signed_token({"filename": 'sample.log'})
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 403
)
def test_ok_with_short_future_skew(self, client: FlaskClient, signer):
with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=1)):
token = signer.generate_signed_token({"filename": 'sample.log'})
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 200
)
def test_ok_with_short_past_skew(self, client: FlaskClient, signer):
with freeze_time(datetime.datetime.utcnow() - datetime.timedelta(seconds=31)):
token = signer.generate_signed_token({"filename": 'sample.log'})
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 200
)
def test_forbidden_with_long_future_skew(self, client: FlaskClient, signer):
with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=10)):
token = signer.generate_signed_token({"filename": 'sample.log'})
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 403
)
def test_forbidden_with_long_past_skew(self, client: FlaskClient, signer):
with freeze_time(datetime.datetime.utcnow() - datetime.timedelta(seconds=40)):
token = signer.generate_signed_token({"filename": 'sample.log'})
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 403
)
def test_wrong_audience(self, client: FlaskClient, different_audience):
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': different_audience.generate_signed_token({"filename": 'sample.log'}),
},
).status_code
== 403
)
@pytest.mark.parametrize("claim_to_remove", ["iat", "exp", "nbf", "aud"])
def test_missing_claims(self, claim_to_remove: str, client: FlaskClient, secret_key):
jwt_dict = {
"aud": "task-instance-logs",
"iat": datetime.datetime.utcnow(),
"nbf": datetime.datetime.utcnow(),
"exp": datetime.datetime.utcnow() + datetime.timedelta(seconds=30),
}
del jwt_dict[claim_to_remove]
jwt_dict.update({"filename": 'sample.log'})
token = jwt.encode(
jwt_dict,
secret_key,
algorithm="HS512",
)
assert (
client.get(
'/log/sample.log',
headers={
'Authorization': token,
},
).status_code
== 403
)