blob: 45a5894ec6936323de59b509cc0d80875f4cdce3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import timedelta
from unittest import mock
import pytest
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.security import permissions
from airflow.utils.dates import parse_execution_date
from airflow.utils.session import create_session
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
class CustomXCom(BaseXCom):
    """XCom backend for tests that tags every value with the deserialization path used.

    ``deserialize_value`` and ``orm_deserialize_value`` delegate to ``BaseXCom`` and
    prefix the result differently. This lets a test tell which of the two the
    endpoint called.
    """

    @classmethod
    def deserialize_value(cls, xcom: XCom):
        """Return the base deserialized value prefixed with ``"real deserialized "``."""
        plain = super().deserialize_value(xcom)
        return f"real deserialized {plain}"

    def orm_deserialize_value(self):
        """Return the base ORM-deserialized value prefixed with ``"orm deserialized "``."""
        plain = super().orm_deserialize_value()
        return f"orm deserialized {plain}"
@pytest.fixture(scope="module")
def configured_app(minimal_app_for_api):
    """Provide the API app with the users the XCom endpoint tests log in as.

    Users created for the duration of this module:

    * ``test``: read access to DAGs, DAG runs, task instances and XComs.
    * ``test_granular_permissions``: read access to DAG runs, task instances and
      XComs, with no global DAG read. Its ``TestGranularDag`` role is granted
      edit/read on the single DAG ``test-dag-id-1``.
    * ``test_no_permissions``: created without a permissions list.

    Every user created here is deleted again on teardown.
    """
    app = minimal_app_for_api
    create_user(
        app,  # type: ignore
        username="test",
        role_name="Test",
        permissions=[
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM),
        ],
    )
    create_user(
        app,  # type: ignore
        username="test_granular_permissions",
        role_name="TestGranularDag",
        permissions=[
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM),
        ],
    )
    # DAG-level grant: the granular role may only see test-dag-id-1.
    app.appbuilder.sm.sync_perm_for_dag(  # type: ignore
        "test-dag-id-1",
        access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
    )
    create_user(app, username="test_no_permissions", role_name="TestNoPermissions")  # type: ignore

    yield app

    # Remove every user created above. test_granular_permissions was previously never
    # deleted, leaving it behind in the app for whatever ran afterwards.
    delete_user(app, username="test")  # type: ignore
    delete_user(app, username="test_granular_permissions")  # type: ignore
    delete_user(app, username="test_no_permissions")  # type: ignore
class TestXComEndpoint:
    """Common base for the XCom endpoint tests: an app test client and a clean database."""

    @staticmethod
    def clean_db():
        """Run the ``clear_db_dags``, ``clear_db_runs`` and ``clear_db_xcom`` helpers, in that order."""
        for clear_table in (clear_db_dags, clear_db_runs, clear_db_xcom):
            clear_table()

    @pytest.fixture(autouse=True)
    def setup_attrs(self, configured_app) -> None:
        """Attach the configured app and a test client, then start from an empty database."""
        self.app = configured_app
        self.client = self.app.test_client()  # type:ignore
        # Drop anything an earlier test may have left behind.
        self.clean_db()

    def teardown_method(self) -> None:
        """Remove any rows (including XComs) the test created."""
        self.clean_db()
class TestGetXComEntry(TestXComEndpoint):
    """Tests for ``GET /dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{key}``."""

    DAG_ID = "test-dag-id"
    TASK_ID = "test-task-id"
    EXECUTION_DATE = "2005-04-02T00:00:00+00:00"
    XCOM_KEY = "test-xcom-key"

    def _seed_default_entry(self) -> str:
        """Store the standard test XCom and return the URL that fetches it."""
        logical_date = parse_execution_date(self.EXECUTION_DATE)
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, logical_date)
        self._create_xcom_entry(self.DAG_ID, run_id, logical_date, self.TASK_ID, self.XCOM_KEY)
        return (
            f"/api/v1/dags/{self.DAG_ID}/dagRuns/{run_id}"
            f"/taskInstances/{self.TASK_ID}/xcomEntries/{self.XCOM_KEY}"
        )

    def test_should_respond_200(self):
        url = self._seed_default_entry()

        response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})

        assert response.status_code == 200
        payload = response.json
        # The timestamp is set at write time; mask it so the comparison is stable.
        payload["timestamp"] = "TIMESTAMP"
        assert payload == {
            "dag_id": self.DAG_ID,
            "execution_date": self.EXECUTION_DATE,
            "key": self.XCOM_KEY,
            "task_id": self.TASK_ID,
            "timestamp": "TIMESTAMP",
            "value": "TEST_VALUE",
        }

    def test_should_raises_401_unauthenticated(self):
        url = self._seed_default_entry()

        assert_401(self.client.get(url))

    def test_should_raise_403_forbidden(self):
        url = self._seed_default_entry()

        response = self.client.get(url, environ_overrides={"REMOTE_USER": "test_no_permissions"})

        assert response.status_code == 403

    def _create_xcom_entry(self, dag_id, run_id, execution_date, task_id, xcom_key, *, backend=XCom):
        """Persist a manual DAG run and one task instance, then store ``"TEST_VALUE"`` under *xcom_key*.

        The value is written through *backend*. The write happens after the
        ``create_session`` block has exited, which is expected to have committed
        the DAG run the XCom refers to.
        """
        with create_session() as session:
            session.add(
                DagRun(
                    dag_id=dag_id,
                    run_id=run_id,
                    execution_date=execution_date,
                    start_date=execution_date,
                    run_type=DagRunType.MANUAL,
                )
            )
            task_instance = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
            task_instance.dag_id = dag_id
            session.add(task_instance)
        backend.set(
            key=xcom_key,
            value="TEST_VALUE",
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
        )

    @pytest.mark.parametrize(
        "query, expected_value",
        [
            pytest.param("?deserialize=true", "real deserialized TEST_VALUE", id="true"),
            pytest.param("?deserialize=false", "orm deserialized TEST_VALUE", id="false"),
            pytest.param("", "orm deserialized TEST_VALUE", id="default"),
        ],
    )
    @conf_vars({("core", "xcom_backend"): "tests.api_connexion.endpoints.test_xcom_endpoint.CustomXCom"})
    def test_custom_xcom_deserialize(self, query, expected_value):
        # With the conf_vars override active this is expected to resolve to CustomXCom.
        xcom_backend = resolve_xcom_backend()
        self._create_xcom_entry("dag", "run", utcnow(), "task", "key", backend=xcom_backend)
        url = f"/api/v1/dags/dag/dagRuns/run/taskInstances/task/xcomEntries/key{query}"

        # Make the endpoint module's ``XCom`` name point at the custom backend for this request.
        with mock.patch("airflow.api_connexion.endpoints.xcom_endpoint.XCom", xcom_backend):
            response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})

        assert response.status_code == 200
        assert response.json["value"] == expected_value
class TestGetXComEntries(TestXComEndpoint):
    """Tests for ``GET /dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries``."""

    EXECUTION_DATE = "2005-04-02T00:00:00+00:00"

    @classmethod
    def _expected_entries(cls, dag_id, task_id):
        """Return the two list items ``_create_xcom_entries`` yields for one task, timestamps masked."""
        return [
            {
                "dag_id": dag_id,
                "execution_date": cls.EXECUTION_DATE,
                "key": f"test-xcom-key-{suffix}",
                "task_id": task_id,
                "timestamp": "TIMESTAMP",
            }
            for suffix in (1, 2)
        ]

    @staticmethod
    def _masked_payload(response):
        """Return ``response.json`` with every entry's ``timestamp`` replaced by ``"TIMESTAMP"``."""
        payload = response.json
        for entry in payload["xcom_entries"]:
            entry["timestamp"] = "TIMESTAMP"
        return payload

    def _seed_single_task(self):
        """Seed ``test-dag-id``/``test-task-id`` with two XComs and return its list URL."""
        dag_id, task_id = "test-dag-id", "test-task-id"
        logical_date = parse_execution_date(self.EXECUTION_DATE)
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, logical_date)
        self._create_xcom_entries(dag_id, run_id, logical_date, task_id)
        return f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries"

    def test_should_respond_200(self):
        url = self._seed_single_task()

        response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})

        assert response.status_code == 200
        assert self._masked_payload(response) == {
            "xcom_entries": self._expected_entries("test-dag-id", "test-task-id"),
            "total_entries": 2,
        }

    def test_should_respond_200_with_tilde_and_access_to_all_dags(self):
        logical_date = parse_execution_date(self.EXECUTION_DATE)
        seeded = [("test-dag-id-1", "test-task-id-1"), ("test-dag-id-2", "test-task-id-2")]
        for dag_id, task_id in seeded:
            run_id = DagRun.generate_run_id(DagRunType.MANUAL, logical_date)
            self._create_xcom_entries(dag_id, run_id, logical_date, task_id)

        response = self.client.get(
            "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
            environ_overrides={"REMOTE_USER": "test"},
        )

        assert response.status_code == 200
        expected = [
            entry for dag_id, task_id in seeded for entry in self._expected_entries(dag_id, task_id)
        ]
        assert self._masked_payload(response) == {"xcom_entries": expected, "total_entries": 4}

    def test_should_respond_200_with_tilde_and_granular_dag_access(self):
        logical_date = parse_execution_date(self.EXECUTION_DATE)
        for dag_id, task_id in (("test-dag-id-1", "test-task-id-1"), ("test-dag-id-2", "test-task-id-2")):
            run_id = DagRun.generate_run_id(DagRunType.MANUAL, logical_date)
            self._create_xcom_entries(dag_id, run_id, logical_date, task_id)
        self._create_invalid_xcom_entries(logical_date)

        response = self.client.get(
            "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
            environ_overrides={"REMOTE_USER": "test_granular_permissions"},
        )

        assert response.status_code == 200
        # This user is expected to see only the XComs of test-dag-id-1.
        assert self._masked_payload(response) == {
            "xcom_entries": self._expected_entries("test-dag-id-1", "test-task-id-1"),
            "total_entries": 2,
        }

    def test_should_raises_401_unauthenticated(self):
        url = self._seed_single_task()

        assert_401(self.client.get(url))

    def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id):
        """Persist a DAG, a manual DAG run and one task instance, then store two XComs.

        The XComs are keyed ``test-xcom-key-1`` and ``test-xcom-key-2``, both with
        value ``"TEST"``. They are written after the session block has exited.
        """
        with create_session() as session:
            session.add(DagModel(dag_id=dag_id))
            session.add(
                DagRun(
                    dag_id=dag_id,
                    run_id=run_id,
                    execution_date=execution_date,
                    start_date=execution_date,
                    run_type=DagRunType.MANUAL,
                )
            )
            task_instance = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
            task_instance.dag_id = dag_id
            session.add(task_instance)
        for suffix in (1, 2):
            XCom.set(
                key=f"test-xcom-key-{suffix}",
                value="TEST",
                run_id=run_id,
                task_id=task_id,
                dag_id=dag_id,
            )

    def _create_invalid_xcom_entries(self, execution_date):
        """Seed unrelated ``invalid_dag`` rows to exercise the endpoint's join query.

        ``invalid_dag`` gets two runs: ``invalid_run_id`` (one day after
        *execution_date*) and ``not_this_run_id`` (at *execution_date*). Two XComs
        are attached to ``invalid_task`` in the latter run.
        """
        with create_session() as session:
            session.add(DagModel(dag_id="invalid_dag"))
            session.add(
                DagRun(
                    dag_id="invalid_dag",
                    run_id="invalid_run_id",
                    execution_date=execution_date + timedelta(days=1),
                    start_date=execution_date,
                    run_type=DagRunType.MANUAL,
                )
            )
            session.add(
                DagRun(
                    dag_id="invalid_dag",
                    run_id="not_this_run_id",
                    execution_date=execution_date,
                    start_date=execution_date,
                    run_type=DagRunType.MANUAL,
                )
            )
            task_instance = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id")
            task_instance.dag_id = "invalid_dag"
            session.add(task_instance)
        for suffix in (1, 2):
            XCom.set(
                key=f"invalid-xcom-key-{suffix}",
                value="TEST",
                run_id="not_this_run_id",
                task_id="invalid_task",
                dag_id="invalid_dag",
            )
class TestPaginationGetXComEntries(TestXComEndpoint):
    """Tests for ``limit``/``offset`` handling on the XCom list endpoint."""

    def setup_method(self):
        """Define the DAG, task and run identifiers shared by the pagination tests."""
        self.dag_id = "test-dag-id"
        self.task_id = "test-task-id"
        self.execution_date = "2005-04-02T00:00:00+00:00"
        self.execution_date_parsed = parse_execution_date(self.execution_date)
        self.run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)

    def _seed_ten_xcoms(self):
        """Persist the DAG run, its task instance and raw XCom rows ``TEST_XCOM_KEY1``..``TEST_XCOM_KEY10``."""
        with create_session() as session:
            dag_run = DagRun(
                dag_id=self.dag_id,
                run_id=self.run_id,
                execution_date=self.execution_date_parsed,
                start_date=self.execution_date_parsed,
                run_type=DagRunType.MANUAL,
            )
            session.add(dag_run)
            task_instance = TaskInstance(EmptyOperator(task_id=self.task_id), run_id=self.run_id)
            task_instance.dag_id = self.dag_id
            session.add(task_instance)
        # The XCom rows need dag_run.id, read here after the first session block has exited.
        with create_session() as session:
            for index in range(1, 11):
                session.add(
                    XCom(
                        dag_run_id=dag_run.id,
                        key=f"TEST_XCOM_KEY{index}",
                        value=b"null",
                        run_id=self.run_id,
                        task_id=self.task_id,
                        dag_id=self.dag_id,
                        timestamp=self.execution_date_parsed,
                    )
                )

    # The expected lists assume keys come back in string order, so KEY10 sorts right after KEY1.
    @pytest.mark.parametrize(
        "query_params, expected_xcom_ids",
        [
            ("limit=1", ["TEST_XCOM_KEY1"]),
            ("limit=2", ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"]),
            (
                "offset=5",
                ["TEST_XCOM_KEY5", "TEST_XCOM_KEY6", "TEST_XCOM_KEY7", "TEST_XCOM_KEY8", "TEST_XCOM_KEY9"],
            ),
            (
                "offset=0",
                [
                    "TEST_XCOM_KEY1",
                    "TEST_XCOM_KEY10",
                    "TEST_XCOM_KEY2",
                    "TEST_XCOM_KEY3",
                    "TEST_XCOM_KEY4",
                    "TEST_XCOM_KEY5",
                    "TEST_XCOM_KEY6",
                    "TEST_XCOM_KEY7",
                    "TEST_XCOM_KEY8",
                    "TEST_XCOM_KEY9",
                ],
            ),
            ("limit=1&offset=5", ["TEST_XCOM_KEY5"]),
            ("limit=1&offset=1", ["TEST_XCOM_KEY10"]),
            ("limit=2&offset=2", ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"]),
        ],
    )
    def test_handle_limit_offset(self, query_params, expected_xcom_ids):
        url = (
            f"/api/v1/dags/{self.dag_id}/dagRuns/{self.run_id}/taskInstances/{self.task_id}/xcomEntries"
            f"?{query_params}"
        )
        self._seed_ten_xcoms()

        response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})

        assert response.status_code == 200
        assert response.json["total_entries"] == 10
        returned_keys = [entry["key"] for entry in response.json["xcom_entries"] if entry]
        assert returned_keys == expected_xcom_ids