blob: 0a19c312fba4eb8b7d7bb3aadeed30d6447698f9 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import ast
import json
from unittest import mock
from airflow.models import Log
def client_with_login(app, expected_response_code=302, **kwargs):
    """Return a test client after posting ``kwargs`` to the login endpoint.

    ``check_password_hash`` in the FAB security manager override module is
    patched to return ``True`` for the duration of the login request, so the
    password check is bypassed (presumably letting any supplied password
    through -- confirm against the security manager).

    :param app: application object exposing ``test_client()``.
    :param expected_response_code: status code the ``/login/`` POST must
        return; an ``AssertionError`` is raised otherwise.
    :param kwargs: form fields sent as the body of the login POST.
    :return: the client that performed the login request.
    """
    hash_check_target = (
        "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash"
    )
    # The patch only needs to be active while the login request is handled.
    with mock.patch(hash_check_target, return_value=True):
        logged_in_client = app.test_client()
        login_response = logged_in_client.post("/login/", data=kwargs)
        assert login_response.status_code == expected_response_code
    return logged_in_client
def client_without_login(app):
    """Return an unauthenticated test client with the public role set to ``Viewer``.

    Mutates ``app.config`` in place: ``AUTH_ROLE_PUBLIC`` is overwritten with
    ``"Viewer"`` before the client is created.

    :param app: application object exposing ``config`` and ``test_client()``.
    :return: a fresh client obtained from ``app.test_client()``.
    """
    # Anonymous users can only view pages when AUTH_ROLE_PUBLIC is something
    # other than the default "Public" role.
    app.config["AUTH_ROLE_PUBLIC"] = "Viewer"
    return app.test_client()
def client_without_login_as_admin(app):
    """Return an unauthenticated test client with the public role set to ``Admin``.

    Mutates ``app.config`` in place: ``AUTH_ROLE_PUBLIC`` is overwritten with
    ``"Admin"`` before the client is created.

    :param app: application object exposing ``config`` and ``test_client()``.
    :return: a fresh client obtained from ``app.test_client()``.
    """
    # With AUTH_ROLE_PUBLIC=Admin, anonymous users are treated as Admin.
    app.config["AUTH_ROLE_PUBLIC"] = "Admin"
    return app.test_client()
def check_content_in_response(text, resp, resp_code=200):
    """Assert that ``resp`` has the expected status and contains ``text``.

    :param text: a single string, or a ``list`` of strings that must all
        appear in the UTF-8 decoded response body.
    :param resp: response object exposing ``data`` (bytes) and ``status_code``.
    :param resp_code: status code the response is required to have.
    :raises AssertionError: if the status code differs or a string is missing;
        the message names the first missing string.
    """
    body = resp.data.decode("utf-8")
    assert resp_code == resp.status_code
    # Treat a lone value as a one-element list so both shapes share one loop.
    expected_fragments = text if isinstance(text, list) else [text]
    for fragment in expected_fragments:
        assert fragment in body, f"Couldn't find {fragment!r}"
def check_content_not_in_response(text, resp, resp_code=200):
    """Assert that ``resp`` has the expected status and does NOT contain ``text``.

    :param text: a single string, or a ``list`` of strings, none of which may
        appear in the UTF-8 decoded response body.
    :param resp: response object exposing ``data`` (bytes) and ``status_code``.
    :param resp_code: status code the response is required to have.
    :raises AssertionError: if the status code differs or a forbidden string
        is present; the message names the first offending string.
    """
    resp_html = resp.data.decode("utf-8")
    assert resp_code == resp.status_code
    if isinstance(text, list):
        for line in text:
            # Name the offending fragment, mirroring check_content_in_response.
            assert line not in resp_html, f"Unexpectedly found {line!r}"
    else:
        assert text not in resp_html, f"Unexpectedly found {text!r}"
def _check_last_log(session, dag_id, event, execution_date, expected_extra=None):
    """Assert on the newest matching ``Log`` row, then clear the ``Log`` table.

    Rows are filtered by ``dag_id``, ``event`` and ``execution_date``, ordered
    by ``dttm`` descending and capped at five. At least one row must match and
    the newest one must have a truthy ``extra``.

    :param session: session object used to run the queries.
    :param expected_extra: when truthy, the newest row's ``extra`` is parsed
        as JSON and must equal this value; falsy values skip the comparison.
    :raises AssertionError: if any of the checks above fails.

    After the checks pass, a delete is issued over the unfiltered ``Log``
    query; no commit is performed here.
    """
    selected_columns = (
        Log.dag_id,
        Log.task_id,
        Log.event,
        Log.execution_date,
        Log.owner,
        Log.extra,
    )
    matching = session.query(*selected_columns).filter(
        Log.dag_id == dag_id,
        Log.event == event,
        Log.execution_date == execution_date,
    )
    logs = matching.order_by(Log.dttm.desc()).limit(5).all()
    assert len(logs) >= 1
    newest = logs[0]
    assert newest.extra
    if expected_extra:
        assert json.loads(newest.extra) == expected_extra
    # Remove every Log row so later checks start from an empty table.
    session.query(Log).delete()
def _check_last_log_masked_connection(session, dag_id, event, execution_date):
    """Assert the newest matching ``Log`` row holds the masked connection payload.

    Rows are filtered by ``dag_id``, ``event`` and ``execution_date``, ordered
    by ``dttm`` descending and capped at five. The newest row's ``extra`` is
    parsed with ``ast.literal_eval`` and must equal the expected dict, in
    which the password and the nested secret values appear as ``"***"``.

    :param session: session object used to run the query.
    :raises AssertionError: if no row matches or the parsed ``extra`` differs.
    """
    selected_columns = (
        Log.dag_id,
        Log.task_id,
        Log.event,
        Log.execution_date,
        Log.owner,
        Log.extra,
    )
    matching = session.query(*selected_columns).filter(
        Log.dag_id == dag_id,
        Log.event == event,
        Log.execution_date == execution_date,
    )
    logs = matching.order_by(Log.dttm.desc()).limit(5).all()
    assert len(logs) >= 1
    masked_connection = {
        "conn_id": "test_conn",
        "conn_type": "http",
        "description": "description",
        "host": "localhost",
        "port": "8080",
        "username": "root",
        "password": "***",
        "extra": {"x_secret": "***", "y_secret": "***"},
    }
    extra = ast.literal_eval(logs[0].extra)
    assert extra == masked_connection
def _check_last_log_masked_variable(session, dag_id, event, execution_date):
    """Assert the newest matching ``Log`` row holds the masked variable payload.

    Rows are filtered by ``dag_id``, ``event`` and ``execution_date``, ordered
    by ``dttm`` descending and capped at five. The newest row's ``extra`` is
    parsed with ``ast.literal_eval`` and must equal
    ``{"key": "x_secret", "val": "***"}``.

    :param session: session object used to run the query.
    :raises AssertionError: if no row matches or the parsed ``extra`` differs.
    """
    selected_columns = (
        Log.dag_id,
        Log.task_id,
        Log.event,
        Log.execution_date,
        Log.owner,
        Log.extra,
    )
    matching = session.query(*selected_columns).filter(
        Log.dag_id == dag_id,
        Log.event == event,
        Log.execution_date == execution_date,
    )
    logs = matching.order_by(Log.dttm.desc()).limit(5).all()
    assert len(logs) >= 1
    masked_variable = {"key": "x_secret", "val": "***"}
    extra_dict = ast.literal_eval(logs[0].extra)
    assert extra_dict == masked_variable