# blob: cbd7c80bb0178879630e5b571c0511d5a83a72d6 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import pickle
from datetime import datetime
from typing import TYPE_CHECKING
from unittest import mock
import httpx
import pytest
import uuid6
from task_sdk import make_client, make_client_w_dry_run, make_client_w_responses
from uuid6 import uuid7
from airflow.sdk import timezone
from airflow.sdk.api.client import RemoteValidationError, ServerResponseError
from airflow.sdk.api.datamodels._generated import (
AssetEventsResponse,
AssetResponse,
ConnectionResponse,
DagRunState,
DagRunStateResponse,
HITLDetailResponse,
HITLUser,
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
DeferTask,
ErrorResponse,
HITLDetailRequestResult,
OKResponse,
PreviousDagRunResult,
RescheduleTask,
TaskRescheduleStartDate,
)
from airflow.utils.state import TerminalTIState
if TYPE_CHECKING:
from time_machine import TimeMachineFixture
class TestClient:
    """Checks on the bare client: dry-run mode, error parsing, retry handling and token renewal."""

    @pytest.mark.parametrize(
        ["path", "json_response"],
        [
            (
                "/task-instances/1/run",
                {
                    "dag_run": {
                        "dag_id": "test_dag",
                        "run_id": "test_run",
                        "logical_date": "2021-01-01T00:00:00Z",
                        "start_date": "2021-01-01T00:00:00Z",
                        "run_type": "manual",
                        "run_after": "2021-01-01T00:00:00Z",
                        "consumed_asset_events": [],
                    },
                    "max_tries": 0,
                    "should_retry": False,
                },
            ),
        ],
    )
    def test_dry_run(self, path, json_response):
        """The dry-run client reports its special base URL and returns the expected JSON for ``path``."""
        dry_client = make_client_w_dry_run()
        assert dry_client.base_url == "dry-run://server"

        reply = dry_client.get(path)

        assert reply.status_code == 200
        assert reply.json() == json_response

    @mock.patch("airflow.sdk.api.client.API_SSL_CERT_PATH", "/capath/does/not/exist/")
    def test_add_capath(self):
        """With the cert path patched to a non-existent location, building a client raises."""

        def always_ok(request: httpx.Request) -> httpx.Response:
            return httpx.Response(status_code=200)

        with pytest.raises(FileNotFoundError) as err:
            make_client(httpx.MockTransport(always_ok))

        assert isinstance(err.value, FileNotFoundError)

    def test_error_parsing(self):
        """A 422 whose ``detail`` is a list is exposed as ``RemoteValidationError`` entries."""
        canned = [
            httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})
        ]
        api = make_client_w_responses(canned)

        with pytest.raises(ServerResponseError) as err:
            api.get("http://error")

        raised = err.value
        assert isinstance(raised, ServerResponseError)
        assert isinstance(raised.detail, list)
        assert raised.detail == [
            RemoteValidationError(loc=["#0"], msg="err", type="required"),
        ]

    def test_error_parsing_plain_text(self):
        """A non-JSON error body raises ``httpx.HTTPStatusError`` that is not a ``ServerResponseError``."""
        canned = [httpx.Response(422, content=b"Internal Server Error")]
        api = make_client_w_responses(canned)

        with pytest.raises(httpx.HTTPStatusError) as err:
            api.get("http://error")

        assert not isinstance(err.value, ServerResponseError)

    def test_error_parsing_other_json(self):
        """A string ``detail`` ends up in the exception args while ``.detail`` stays ``None``."""
        canned = [httpx.Response(404, json={"detail": "Not found"})]
        api = make_client_w_responses(canned)

        with pytest.raises(ServerResponseError) as err:
            api.get("http://error")

        raised = err.value
        assert raised.args == ("Not found",)
        assert raised.detail is None

    def test_server_response_error_pickling(self):
        """``ServerResponseError`` survives a pickle round trip with its attributes intact."""
        expected_payload = {"detail": {"message": "Invalid input"}}
        canned = [httpx.Response(404, json={"detail": {"message": "Invalid input"}})]
        api = make_client_w_responses(canned)

        with pytest.raises(ServerResponseError) as exc_info:
            api.get("http://error")

        original = exc_info.value
        assert original.args == ("Server returned error",)
        assert original.detail == expected_payload

        # Round-trip through pickle; this must not raise.
        restored = pickle.loads(pickle.dumps(original))
        assert isinstance(restored, ServerResponseError)

        # The restored exception carries the same response, detail and request data.
        assert restored.response.json() == expected_payload
        assert restored.detail == expected_payload
        assert restored.response.status_code == 404
        assert restored.request.url == "http://error"

    @mock.patch("time.sleep", return_value=None)
    def test_retry_handling_unrecoverable_error(self, mock_sleep):
        """Repeated 500s exhaust the retries and surface as a plain ``httpx.HTTPStatusError``."""
        queue: list[httpx.Response] = [
            *[httpx.Response(500, text="Internal Server Error")] * 6,
            httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}),
            httpx.Response(400, json={"detail": "Should not get here"}),
        ]
        api = make_client_w_responses(queue)

        with pytest.raises(httpx.HTTPStatusError) as err:
            api.get("http://error")

        assert not isinstance(err.value, ServerResponseError)
        # Three canned responses are left over, and sleep was called four times.
        assert len(queue) == 3
        assert mock_sleep.call_count == 4

    @mock.patch("time.sleep", return_value=None)
    def test_retry_handling_recovered(self, mock_sleep):
        """Two 500s followed by a 200 yield a successful response after two sleeps."""
        queue: list[httpx.Response] = [
            *[httpx.Response(500, text="Internal Server Error")] * 2,
            httpx.Response(200, json={"detail": "Recovered from error"}),
            httpx.Response(400, json={"detail": "Should not get here"}),
        ]
        api = make_client_w_responses(queue)

        reply = api.get("http://error")

        assert reply.status_code == 200
        assert len(queue) == 1
        assert mock_sleep.call_count == 2

    @mock.patch("time.sleep", return_value=None)
    def test_retry_handling_overload(self, mock_sleep):
        """A 429 with ``Retry-After: 37`` leads to a single sleep of 37 before the retry succeeds."""
        queue: list[httpx.Response] = [
            httpx.Response(429, text="I am really busy atm, please back-off", headers={"Retry-After": "37"}),
            httpx.Response(200, json={"detail": "Recovered from error"}),
            httpx.Response(400, json={"detail": "Should not get here"}),
        ]
        api = make_client_w_responses(queue)

        reply = api.get("http://error")

        assert reply.status_code == 200
        assert len(queue) == 1
        assert mock_sleep.call_count == 1
        assert mock_sleep.call_args[0][0] == 37

    @mock.patch("time.sleep", return_value=None)
    def test_retry_handling_non_retry_error(self, mock_sleep):
        """A 422 is raised immediately: no sleep, and only one canned response is consumed."""
        queue: list[httpx.Response] = [
            httpx.Response(422, json={"detail": "Somehow this is a bad request"}),
            httpx.Response(400, json={"detail": "Should not get here"}),
        ]
        api = make_client_w_responses(queue)

        with pytest.raises(ServerResponseError) as err:
            api.get("http://error")

        assert len(queue) == 1
        assert mock_sleep.call_count == 0
        assert err.value.args == ("Somehow this is a bad request",)

    @mock.patch("time.sleep", return_value=None)
    def test_retry_handling_ok(self, mock_sleep):
        """An immediate 200 is returned without any sleep."""
        queue: list[httpx.Response] = [
            httpx.Response(200, json={"detail": "Recovered from error"}),
            httpx.Response(400, json={"detail": "Should not get here"}),
        ]
        api = make_client_w_responses(queue)

        reply = api.get("http://error")

        assert reply.status_code == 200
        assert len(queue) == 1
        assert mock_sleep.call_count == 0

    def test_token_renewal(self):
        """A ``Refreshed-API-Token`` header updates the auth token, even on an error response."""
        queue: list[httpx.Response] = [
            httpx.Response(200, json={"ok": "1"}),
            httpx.Response(404, json={"var": "not_found"}, headers={"Refreshed-API-Token": "abc"}),
            httpx.Response(200, json={"ok": "3"}),
        ]
        api = make_client_w_responses(queue)

        first = api.get("/")
        assert first.status_code == 200
        assert api.auth is not None
        assert not api.auth.token

        with pytest.raises(ServerResponseError):
            api.get("/")

        # Even though it was Not Found, the refreshed-token header must still be honoured.
        assert api.auth is not None
        assert api.auth.token == "abc"

        # The following request has to be sent with the new token.
        third = api.get("/")
        assert third.status_code == 200
        assert third.request.headers["Authorization"] == "Bearer abc"

    @pytest.mark.parametrize(
        ["status_code", "description"],
        [
            (399, "status code < 400"),
            (301, "3xx redirect status code"),
            (600, "status code >= 600"),
        ],
    )
    def test_server_response_error_invalid_status_codes(self, status_code, description):
        """``ServerResponseError.from_response`` yields ``None`` for these status codes."""
        reply = httpx.Response(status_code, json={"detail": f"Test {description}"})
        assert ServerResponseError.from_response(reply) is None
class TestTaskInstanceOperations:
    """
    Test the task instance operations exposed as ``client.task_instances``. While the operations
    are simple, these tests cover the basic client behaviour for task instances, including the
    endpoint used and response parsing.
    """

    @mock.patch("time.sleep", return_value=None)  # patched so retries do not slow the test down
    def test_task_instance_start(self, mock_sleep, make_ti_context):
        """Starting a task instance succeeds on the third attempt after two 500 responses."""
        task_uuid = uuid6.uuid7()
        start_date = "2024-10-31T12:00:00Z"
        expected_context = make_ti_context(
            start_date=start_date,
            logical_date="2024-10-31T12:00:00Z",
            run_type="manual",
        )

        # Counting the calls lets us verify that retrying really happens.
        attempts = 0

        def respond(request: httpx.Request) -> httpx.Response:
            nonlocal attempts
            attempts += 1
            if attempts < 3:
                return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
            if request.url.path != f"/task-instances/{task_uuid}/run":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            payload = json.loads(request.read())
            assert payload["pid"] == 100
            assert payload["start_date"] == start_date
            assert payload["state"] == "running"
            return httpx.Response(
                status_code=200,
                json=expected_context.model_dump(mode="json"),
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.task_instances.start(task_uuid, 100, start_date)

        assert outcome == expected_context
        assert attempts == 3

    @pytest.mark.parametrize(
        "state", [state for state in TerminalTIState if state != TerminalTIState.SUCCESS]
    )
    def test_task_instance_finish(self, state):
        """Finishing a task instance sends end date, state and rendered map index to ``/state``."""
        task_uuid = uuid6.uuid7()

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != f"/task-instances/{task_uuid}/state":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            payload = json.loads(request.read())
            assert payload["end_date"] == "2024-10-31T12:00:00Z"
            assert payload["state"] == state
            assert payload["rendered_map_index"] == "test"
            return httpx.Response(
                status_code=204,
            )

        api = make_client(transport=httpx.MockTransport(respond))
        api.task_instances.finish(
            task_uuid, state=state, when="2024-10-31T12:00:00Z", rendered_map_index="test"
        )

    def test_task_instance_heartbeat(self):
        """A heartbeat sends the pid to the ``/heartbeat`` endpoint."""
        task_uuid = uuid6.uuid7()

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != f"/task-instances/{task_uuid}/heartbeat":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            payload = json.loads(request.read())
            assert payload["pid"] == 100
            return httpx.Response(
                status_code=204,
            )

        api = make_client(transport=httpx.MockTransport(respond))
        api.task_instances.heartbeat(task_uuid, 100)

    def test_task_instance_defer(self):
        """Deferring sends state ``deferred`` plus the trigger details to ``/state``."""
        task_uuid = uuid6.uuid7()
        defer_msg = DeferTask(
            classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
            next_method="execute_complete",
            trigger_kwargs={
                "__type": "dict",
                "__var": {
                    "moment": {"__type": "datetime", "__var": 1730982899.0},
                    "end_from_trigger": False,
                },
            },
            next_kwargs={"__type": "dict", "__var": {}},
        )

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != f"/task-instances/{task_uuid}/state":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            payload = json.loads(request.read())
            assert payload["state"] == "deferred"
            assert payload["trigger_kwargs"] == defer_msg.trigger_kwargs
            assert (
                payload["classpath"] == "airflow.providers.standard.triggers.temporal.DateTimeTrigger"
            )
            assert payload["next_method"] == "execute_complete"
            return httpx.Response(
                status_code=204,
            )

        api = make_client(transport=httpx.MockTransport(respond))
        api.task_instances.defer(task_uuid, defer_msg)

    def test_task_instance_reschedule(self):
        """Rescheduling sends state ``up_for_reschedule`` with reschedule and end dates."""
        task_uuid = uuid6.uuid7()

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != f"/task-instances/{task_uuid}/state":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            payload = json.loads(request.read())
            assert payload["state"] == "up_for_reschedule"
            assert payload["reschedule_date"] == "2024-10-31T12:00:00Z"
            assert payload["end_date"] == "2024-10-31T12:00:00Z"
            return httpx.Response(
                status_code=204,
            )

        api = make_client(transport=httpx.MockTransport(respond))
        reschedule_msg = RescheduleTask(
            reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
            end_date=timezone.parse("2024-10-31T12:00:00Z"),
        )
        api.task_instances.reschedule(task_uuid, reschedule_msg)

    def test_task_instance_up_for_retry(self):
        """Retrying sends state ``up_for_retry`` with end date and rendered map index."""
        task_uuid = uuid6.uuid7()

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != f"/task-instances/{task_uuid}/state":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            payload = json.loads(request.read())
            assert payload["state"] == "up_for_retry"
            assert payload["end_date"] == "2024-10-31T12:00:00Z"
            assert payload["rendered_map_index"] == "test"
            return httpx.Response(
                status_code=204,
            )

        api = make_client(transport=httpx.MockTransport(respond))
        api.task_instances.retry(
            task_uuid, end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test"
        )

    @pytest.mark.parametrize(
        "rendered_fields",
        [
            pytest.param({"field1": "rendered_value1", "field2": "rendered_value2"}, id="simple-rendering"),
            pytest.param(
                {
                    "field1": "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
                    "{'att1': 'test', 'att2': 'test2'), "
                    "'nested2': ClassWithCustomAttributes("
                    "{'att3': 'test3', 'att4': 'test4')"
                },
                id="complex-rendering",
            ),
        ],
    )
    def test_taskinstance_set_rtif_success(self, rendered_fields):
        """Setting rendered fields returns ``OKResponse(ok=True)`` when the server replies 201."""
        task_uuid = uuid6.uuid7()

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != f"/task-instances/{task_uuid}/rtif":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=201,
                json={"message": "Rendered task instance fields successfully set"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.task_instances.set_rtif(id=task_uuid, body=rendered_fields)

        assert outcome == OKResponse(ok=True)

    def test_get_count_basic(self):
        """``get_count`` with only ``dag_id`` queries ``/task-instances/count`` and parses the count."""

        def respond(request: httpx.Request) -> httpx.Response:
            assert request.url.path == "/task-instances/count"
            assert request.url.params.get("dag_id") == "test_dag"
            return httpx.Response(200, json=5)

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.task_instances.get_count(dag_id="test_dag")

        assert outcome.count == 5

    def test_get_count_with_all_params(self):
        """``get_count`` forwards the optional filters as query parameters."""
        logical_dates_str = ["2024-01-01T00:00:00+00:00", "2024-01-02T00:00:00+00:00"]
        logical_dates = [timezone.parse(stamp) for stamp in logical_dates_str]
        task_ids = ["task1", "task2"]
        states = ["success", "failed"]

        def respond(request: httpx.Request) -> httpx.Response:
            assert request.url.path == "/task-instances/count"
            assert request.method == "GET"
            query = request.url.params
            assert query["dag_id"] == "test_dag"
            assert query.get_list("task_ids") == task_ids
            assert query["task_group_id"] == "group1"
            assert query.get_list("logical_dates") == logical_dates_str
            # run_ids is not passed below, so it must be absent from the query
            assert query.get_list("run_ids") == []
            assert query.get_list("states") == states
            assert query["map_index"] == "0"
            return httpx.Response(200, json=10)

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.task_instances.get_count(
            dag_id="test_dag",
            map_index=0,
            task_ids=task_ids,
            task_group_id="group1",
            logical_dates=logical_dates,
            states=states,
        )

        assert outcome.count == 10

    def test_get_task_states_basic(self):
        """``get_task_states`` with ``dag_id`` and ``task_group_id`` parses the returned mapping."""
        expected_states = {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}

        def respond(request: httpx.Request) -> httpx.Response:
            assert request.url.path == "/task-instances/states"
            assert request.url.params.get("dag_id") == "test_dag"
            assert request.url.params.get("task_group_id") == "group1"
            return httpx.Response(
                200, json={"task_states": {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}}
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.task_instances.get_task_states(dag_id="test_dag", task_group_id="group1")

        assert outcome.task_states == expected_states

    def test_get_task_states_with_all_params(self):
        """``get_task_states`` forwards the optional filters as query parameters."""
        logical_dates_str = ["2024-01-01T00:00:00+00:00", "2024-01-02T00:00:00+00:00"]
        logical_dates = [timezone.parse(stamp) for stamp in logical_dates_str]
        expected_states = {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}

        def respond(request: httpx.Request) -> httpx.Response:
            assert request.url.path == "/task-instances/states"
            assert request.method == "GET"
            query = request.url.params
            assert query["dag_id"] == "test_dag"
            assert query["task_group_id"] == "group1"
            assert query.get_list("logical_dates") == logical_dates_str
            # task_ids and run_ids are not passed below, so they must be absent
            assert query.get_list("task_ids") == []
            assert query.get_list("run_ids") == []
            assert query.get("map_index") == "0"
            return httpx.Response(
                200, json={"task_states": {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}}
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.task_instances.get_task_states(
            dag_id="test_dag",
            map_index=0,
            task_group_id="group1",
            logical_dates=logical_dates,
        )

        assert outcome.task_states == expected_states
class TestVariableOperations:
    """
    Test the variable operations exposed as ``client.variables``. While the operations are simple,
    these tests cover the basic client behaviour for variables, including the endpoint used and
    response parsing.
    """

    @mock.patch("time.sleep", return_value=None)  # patched so retries do not slow the test down
    def test_variable_get_success(self, mock_sleep):
        """Fetching a variable succeeds on the second attempt after one 500 response."""
        # Counting the calls lets us verify that retrying really happens.
        attempts = 0

        def respond(request: httpx.Request) -> httpx.Response:
            nonlocal attempts
            attempts += 1
            if attempts < 2:
                return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
            if request.url.path != "/variables/test_key":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=200,
                json={"key": "test_key", "value": "test_value"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.variables.get(key="test_key")

        assert isinstance(outcome, VariableResponse)
        assert outcome.key == "test_key"
        assert outcome.value == "test_value"
        assert attempts == 2

    def test_variable_not_found(self):
        """A 404 for a variable is returned as an ``ErrorResponse`` rather than raised."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/variables/non_existent_var":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=404,
                json={
                    "detail": {
                        "message": "Variable with key 'non_existent_var' not found",
                        "reason": "not_found",
                    }
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.variables.get(key="non_existent_var")

        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == ErrorType.VARIABLE_NOT_FOUND
        assert outcome.detail == {"key": "non_existent_var"}

    @mock.patch("time.sleep", return_value=None)
    def test_variable_get_500_error(self, mock_sleep):
        """A persistent 500 from the server raises ``ServerResponseError``."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/variables/test_key":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=500,
                headers=[("content-Type", "application/json")],
                json={
                    "reason": "internal_server_error",
                    "message": "Internal Server Error",
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))

        with pytest.raises(ServerResponseError):
            api.variables.get(
                key="test_key",
            )

    def test_variable_set_success(self):
        """Setting a variable returns ``OKResponse(ok=True)`` when the server replies 201."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/variables/test_key":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=201,
                json={"message": "Variable successfully set"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.variables.set(key="test_key", value="test_value", description="test_description")

        assert outcome == OKResponse(ok=True)

    def test_variable_delete_success(self):
        """Deleting a variable issues a DELETE and returns ``OKResponse(ok=True)``."""

        def respond(request: httpx.Request) -> httpx.Response:
            is_delete = request.method == "DELETE"
            if not (is_delete and request.url.path == "/variables/test_key"):
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=200,
                json={"count": 1},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.variables.delete(key="test_key")

        assert outcome == OKResponse(ok=True)
class TestXCOMOperations:
    """
    Test the XCom operations exposed as ``client.xcoms``. While the operations are simple, these
    tests cover the basic client behaviour for XComs, including the endpoint used and response
    parsing.
    """

    @pytest.mark.parametrize(
        "value",
        [
            pytest.param("value1", id="string-value"),
            pytest.param({"key1": "value1"}, id="dict-value"),
            pytest.param('{"key1": "value1"}', id="dict-str-value"),
            pytest.param(["value1", "value2"], id="list-value"),
            pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"),
        ],
    )
    @mock.patch("time.sleep", return_value=None)  # patched so retries do not slow the test down
    def test_xcom_get_success(self, mock_sleep, value):
        """Fetching an XCom succeeds on the third attempt and returns the value unchanged."""
        # Counting the calls lets us verify that retrying really happens.
        attempts = 0

        def respond(request: httpx.Request) -> httpx.Response:
            nonlocal attempts
            attempts += 1
            if attempts < 3:
                return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
            if request.url.path != "/xcoms/dag_id/run_id/task_id/key":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=201,
                json={"key": "test_key", "value": value},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.xcoms.get(
            dag_id="dag_id",
            run_id="run_id",
            task_id="task_id",
            key="key",
        )

        assert isinstance(outcome, XComResponse)
        assert outcome.key == "test_key"
        assert outcome.value == value
        assert attempts == 3

    def test_xcom_get_success_with_map_index(self):
        """``map_index`` is sent as a query parameter when fetching an XCom."""

        def respond(request: httpx.Request) -> httpx.Response:
            path_matches = request.url.path == "/xcoms/dag_id/run_id/task_id/key"
            if not (path_matches and request.url.params.get("map_index") == "2"):
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=201,
                json={"key": "test_key", "value": "test_value"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.xcoms.get(
            dag_id="dag_id",
            run_id="run_id",
            task_id="task_id",
            key="key",
            map_index=2,
        )

        assert isinstance(outcome, XComResponse)
        assert outcome.key == "test_key"
        assert outcome.value == "test_value"

    def test_xcom_get_success_with_include_prior_dates(self):
        """``include_prior_dates`` is sent as a (non-empty) query parameter when fetching an XCom."""

        def respond(request: httpx.Request) -> httpx.Response:
            path_matches = request.url.path == "/xcoms/dag_id/run_id/task_id/key"
            if not (path_matches and request.url.params.get("include_prior_dates")):
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=201,
                json={"key": "test_key", "value": "test_value"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.xcoms.get(
            dag_id="dag_id",
            run_id="run_id",
            task_id="task_id",
            key="key",
            include_prior_dates=True,
        )

        assert isinstance(outcome, XComResponse)
        assert outcome.key == "test_key"
        assert outcome.value == "test_value"

    @mock.patch("time.sleep", return_value=None)
    def test_xcom_get_500_error(self, mock_sleep):
        """A persistent 500 from the server raises ``ServerResponseError``."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/xcoms/dag_id/run_id/task_id/key":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=500,
                headers=[("content-Type", "application/json")],
                json={
                    "reason": "invalid_format",
                    "message": "XCom value is not a valid JSON",
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))

        with pytest.raises(ServerResponseError):
            api.xcoms.get(
                dag_id="dag_id",
                run_id="run_id",
                task_id="task_id",
                key="key",
            )

    @pytest.mark.parametrize(
        "values",
        [
            pytest.param("value1", id="string-value"),
            pytest.param({"key1": "value1"}, id="dict-value"),
            pytest.param(["value1", "value2"], id="list-value"),
            pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"),
        ],
    )
    def test_xcom_set_success(self, values):
        """Setting an XCom sends the value as the JSON request body."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/xcoms/dag_id/run_id/task_id/key":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            assert json.loads(request.read()) == values
            return httpx.Response(
                status_code=201,
                json={"message": "XCom successfully set"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.xcoms.set(
            dag_id="dag_id",
            run_id="run_id",
            task_id="task_id",
            key="key",
            value=values,
        )

        assert outcome == OKResponse(ok=True)

    def test_xcom_set_with_map_index(self):
        """``map_index`` is sent as a query parameter when setting an XCom."""

        def respond(request: httpx.Request) -> httpx.Response:
            path_matches = request.url.path == "/xcoms/dag_id/run_id/task_id/key"
            if not (path_matches and request.url.params.get("map_index") == "2"):
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            assert json.loads(request.read()) == "value1"
            return httpx.Response(
                status_code=201,
                json={"message": "XCom successfully set"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.xcoms.set(
            dag_id="dag_id",
            run_id="run_id",
            task_id="task_id",
            key="key",
            value="value1",
            map_index=2,
        )

        assert outcome == OKResponse(ok=True)

    def test_xcom_set_with_mapped_length(self):
        """``map_index`` and ``mapped_length`` are both sent as query parameters when setting an XCom."""

        def respond(request: httpx.Request) -> httpx.Response:
            request_matches = (
                request.url.path == "/xcoms/dag_id/run_id/task_id/key"
                and request.url.params.get("map_index") == "2"
                and request.url.params.get("mapped_length") == "3"
            )
            if not request_matches:
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            assert json.loads(request.read()) == "value1"
            return httpx.Response(
                status_code=201,
                json={"message": "XCom successfully set"},
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.xcoms.set(
            dag_id="dag_id",
            run_id="run_id",
            task_id="task_id",
            key="key",
            value="value1",
            map_index=2,
            mapped_length=3,
        )

        assert outcome == OKResponse(ok=True)
class TestConnectionOperations:
    """
    Test the connection operations exposed as ``client.connections``. While the operations are
    simple, these tests cover the basic client behaviour for connections, including the endpoint
    used and response parsing.
    """

    def test_connection_get_success(self):
        """A 200 reply for a connection is parsed into a ``ConnectionResponse``."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/connections/test_conn":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=200,
                json={
                    "conn_id": "test_conn",
                    "conn_type": "mysql",
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.connections.get(conn_id="test_conn")

        assert isinstance(outcome, ConnectionResponse)
        assert outcome.conn_id == "test_conn"
        assert outcome.conn_type == "mysql"

    def test_connection_get_404_not_found(self):
        """A 404 reply for a connection is returned as an ``ErrorResponse`` rather than raised."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/connections/test_conn":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=404,
                json={
                    "reason": "not_found",
                    "message": "Connection with ID test_conn not found",
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.connections.get(conn_id="test_conn")

        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == ErrorType.CONNECTION_NOT_FOUND
class TestAssetEventOperations:
    """Test the asset event lookup exposed as ``client.asset_events``."""

    @pytest.mark.parametrize(
        "request_params",
        [
            {"name": "this_asset", "uri": "s3://bucket/key"},
            {"alias_name": "this_asset_alias"},
        ],
    )
    def test_by_name_get_success(self, request_params):
        """Asset events are fetched by asset or by alias and parsed into ``AssetEventsResponse``."""

        def respond(request: httpx.Request) -> httpx.Response:
            query = request.url.params
            endpoint = request.url.path
            if endpoint == "/asset-events/by-asset":
                assert query.get("name") == request_params.get("name")
                assert query.get("uri") == request_params.get("uri")
            elif endpoint == "/asset-events/by-asset-alias":
                assert query.get("name") == request_params.get("alias_name")
            else:
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})

            # Both endpoints answer with the same single-event payload.
            return httpx.Response(
                status_code=200,
                json={
                    "asset_events": [
                        {
                            "id": 1,
                            "asset": {
                                "name": "this_asset",
                                "uri": "s3://bucket/key",
                                "group": "asset",
                            },
                            "created_dagruns": [],
                            "timestamp": "2023-01-01T00:00:00Z",
                        }
                    ]
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.asset_events.get(**request_params)

        assert isinstance(outcome, AssetEventsResponse)
        assert len(outcome.asset_events) == 1
        only_event = outcome.asset_events[0]
        assert only_event.asset.name == "this_asset"
        assert only_event.asset.uri == "s3://bucket/key"
class TestAssetOperations:
    """Test the asset lookup exposed as ``client.assets``, by name and by URI."""

    @pytest.mark.parametrize(
        "request_params",
        [
            {"name": "this_asset"},
            {"uri": "s3://bucket/key"},
        ],
    )
    def test_by_name_get_success(self, request_params):
        """A 200 reply from either asset endpoint is parsed into an ``AssetResponse``."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path not in ("/assets/by-name", "/assets/by-uri"):
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=200,
                json={
                    "name": "this_asset",
                    "uri": "s3://bucket/key",
                    "group": "asset",
                    "extra": {"foo": "bar"},
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.assets.get(**request_params)

        assert isinstance(outcome, AssetResponse)
        assert outcome.name == "this_asset"
        assert outcome.uri == "s3://bucket/key"

    @pytest.mark.parametrize(
        "request_params",
        [
            {"name": "this_asset"},
            {"uri": "s3://bucket/key"},
        ],
    )
    def test_by_name_get_404_not_found(self, request_params):
        """A 404 reply from either asset endpoint is returned as an ``ErrorResponse``."""

        def respond(request: httpx.Request) -> httpx.Response:
            if request.url.path not in ("/assets/by-name", "/assets/by-uri"):
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})
            return httpx.Response(
                status_code=404,
                json={
                    "detail": {
                        "message": "Asset with name non_existent not found",
                        "reason": "not_found",
                    }
                },
            )

        api = make_client(transport=httpx.MockTransport(respond))
        outcome = api.assets.get(**request_params)

        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == ErrorType.ASSET_NOT_FOUND
class TestDagRunOperations:
    """Exercise ``client.dag_runs`` (trigger, clear, state, count, previous) against a mocked API server."""

    def test_trigger(self):
        """A 204 reply to a trigger request is reported as ``OKResponse(ok=True)``."""

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_trigger/test_run_id":
                return httpx.Response(status_code=400, json={"detail": "Bad Request"})

            payload = json.loads(request.read())
            assert payload["logical_date"] == "2025-01-01T00:00:00Z"
            assert payload["conf"] == {}
            # `reset_dag_run` is left at its default here, so it must not appear in the request body.
            assert "reset_dag_run" not in payload
            return httpx.Response(status_code=204)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.trigger(
            dag_id="test_trigger",
            run_id="test_run_id",
            logical_date=timezone.datetime(2025, 1, 1),
        )

        assert outcome == OKResponse(ok=True)

    def test_trigger_conflict(self):
        """Without ``reset_dag_run``, a 409 "already_exists" reply becomes a ``DAGRUN_ALREADY_EXISTS`` error."""
        conflict_payload = {
            "detail": {
                "reason": "already_exists",
                "message": "A Dag Run already exists for Dag test_trigger_conflict with run id test_run_id",
            }
        }

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_trigger_conflict/test_run_id":
                return httpx.Response(status_code=422)
            return httpx.Response(status_code=409, json=conflict_payload)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.trigger(
            dag_id="test_trigger_conflict",
            run_id="test_run_id",
        )

        assert outcome == ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)

    def test_trigger_conflict_reset_dag_run(self):
        """With ``reset_dag_run=True``, a 409 on trigger plus a 204 on ``/clear`` ends in ``OKResponse``."""
        conflict_payload = {
            "detail": {
                "reason": "already_exists",
                "message": "A Dag Run already exists for Dag test_trigger_conflict with run id test_run_id",
            }
        }

        def handle_request(request: httpx.Request) -> httpx.Response:
            path = request.url.path
            # The trigger endpoint always reports a conflict ...
            if path == "/dag-runs/test_trigger_conflict_reset/test_run_id":
                return httpx.Response(status_code=409, json=conflict_payload)
            # ... while the clear endpoint for the same run succeeds.
            if path == "/dag-runs/test_trigger_conflict_reset/test_run_id/clear":
                return httpx.Response(status_code=204)
            return httpx.Response(status_code=422)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.trigger(
            dag_id="test_trigger_conflict_reset",
            run_id="test_run_id",
            reset_dag_run=True,
        )

        assert outcome == OKResponse(ok=True)

    def test_clear(self):
        """A 204 reply from the clear endpoint is reported as ``OKResponse(ok=True)``."""

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_clear/test_run_id/clear":
                return httpx.Response(status_code=422)
            return httpx.Response(status_code=204)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.clear(dag_id="test_clear", run_id="test_run_id")

        assert outcome == OKResponse(ok=True)

    def test_get_state(self):
        """The ``state`` field of the reply is parsed into a ``DagRunStateResponse``."""

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_state/test_run_id/state":
                return httpx.Response(status_code=422)
            return httpx.Response(status_code=200, json={"state": "running"})

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_state(
            dag_id="test_state",
            run_id="test_run_id",
        )

        assert outcome == DagRunStateResponse(state=DagRunState.RUNNING)

    def test_get_count_basic(self):
        """``get_count`` sends ``dag_id`` as a query parameter and exposes the reply as ``.count``."""

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/count":
                return httpx.Response(status_code=422)
            assert request.url.params["dag_id"] == "test_dag"
            return httpx.Response(status_code=200, json=1)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_count(dag_id="test_dag")

        assert outcome.count == 1

    def test_get_count_with_states(self):
        """``states`` is sent as a repeated query parameter, in the order given."""

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/count":
                return httpx.Response(status_code=422)
            query = request.url.params
            assert query["dag_id"] == "test_dag"
            assert query.get_list("states") == ["success", "failed"]
            return httpx.Response(status_code=200, json=2)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_count(
            dag_id="test_dag",
            states=["success", "failed"],
        )

        assert outcome.count == 2

    def test_get_count_with_logical_dates(self):
        """``logical_dates`` is sent as a repeated query parameter of ISO-formatted timestamps."""
        first_date = timezone.datetime(2025, 1, 1)
        second_date = timezone.datetime(2025, 1, 2)
        expected_query_values = [first_date.isoformat(), second_date.isoformat()]

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/count":
                return httpx.Response(status_code=422)
            query = request.url.params
            assert query["dag_id"] == "test_dag"
            assert query.get_list("logical_dates") == expected_query_values
            return httpx.Response(status_code=200, json=2)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_count(
            dag_id="test_dag",
            logical_dates=[first_date, second_date],
        )

        assert outcome.count == 2

    def test_get_count_with_run_ids(self):
        """``run_ids`` is sent as a repeated query parameter, in the order given."""

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/count":
                return httpx.Response(status_code=422)
            query = request.url.params
            assert query["dag_id"] == "test_dag"
            assert query.get_list("run_ids") == ["run1", "run2"]
            return httpx.Response(status_code=200, json=2)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_count(
            dag_id="test_dag",
            run_ids=["run1", "run2"],
        )

        assert outcome.count == 2

    def test_get_previous_basic(self):
        """``get_previous`` with ``dag_id`` and ``logical_date`` wraps the reply in ``PreviousDagRunResult``."""
        logical_date = datetime(2024, 1, 15, 12, tzinfo=timezone.utc)
        # A complete DagRun document, as served by the mocked endpoint.
        previous_run_payload = {
            "dag_id": "test_dag",
            "run_id": "prev_run",
            "logical_date": "2024-01-14T12:00:00+00:00",
            "start_date": "2024-01-14T12:05:00+00:00",
            "run_after": "2024-01-14T12:00:00+00:00",
            "run_type": "scheduled",
            "state": "success",
            "consumed_asset_events": [],
        }

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_dag/previous":
                return httpx.Response(status_code=422)
            assert request.url.params["logical_date"] == logical_date.isoformat()
            return httpx.Response(status_code=200, json=previous_run_payload)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_previous(
            dag_id="test_dag",
            logical_date=logical_date,
        )

        assert isinstance(outcome, PreviousDagRunResult)
        previous_run = outcome.dag_run
        assert previous_run.dag_id == "test_dag"
        assert previous_run.run_id == "prev_run"
        assert previous_run.state == "success"

    def test_get_previous_with_state_filter(self):
        """``get_previous`` forwards ``state`` as a query parameter alongside ``logical_date``."""
        logical_date = datetime(2024, 1, 15, 12, tzinfo=timezone.utc)
        # A complete DagRun document, as served by the mocked endpoint.
        previous_run_payload = {
            "dag_id": "test_dag",
            "run_id": "prev_success_run",
            "logical_date": "2024-01-14T12:00:00+00:00",
            "start_date": "2024-01-14T12:05:00+00:00",
            "run_after": "2024-01-14T12:00:00+00:00",
            "run_type": "scheduled",
            "state": "success",
            "consumed_asset_events": [],
        }

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_dag/previous":
                return httpx.Response(status_code=422)
            query = request.url.params
            assert query["logical_date"] == logical_date.isoformat()
            assert query["state"] == "success"
            return httpx.Response(status_code=200, json=previous_run_payload)

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_previous(
            dag_id="test_dag",
            logical_date=logical_date,
            state="success",
        )

        assert isinstance(outcome, PreviousDagRunResult)
        previous_run = outcome.dag_run
        assert previous_run.dag_id == "test_dag"
        assert previous_run.run_id == "prev_success_run"
        assert previous_run.state == "success"

    def test_get_previous_not_found(self):
        """A JSON ``null`` reply yields a ``PreviousDagRunResult`` whose ``dag_run`` is ``None``."""
        logical_date = datetime(2024, 1, 15, 12, tzinfo=timezone.utc)

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != "/dag-runs/test_dag/previous":
                return httpx.Response(status_code=422)
            assert request.url.params["logical_date"] == logical_date.isoformat()
            # The mocked server answers with a bare `null` when there is no previous Dag run.
            return httpx.Response(status_code=200, content="null")

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).dag_runs.get_previous(
            dag_id="test_dag",
            logical_date=logical_date,
        )

        assert isinstance(outcome, PreviousDagRunResult)
        assert outcome.dag_run is None
class TestTaskRescheduleOperations:
    """Exercise ``client.task_instances.get_reschedule_start_date`` against a mocked API server."""

    def test_get_start_date(self):
        """The start date served for a task instance is wrapped in ``TaskRescheduleStartDate``."""
        ti_id = uuid6.uuid7()
        start_date_path = f"/task-reschedules/{ti_id}/start_date"

        def handle_request(request: httpx.Request) -> httpx.Response:
            if request.url.path != start_date_path:
                return httpx.Response(status_code=422)
            # Query parameter values arrive as strings, hence "1" rather than 1.
            assert request.url.params.get("try_number") == "1"
            return httpx.Response(status_code=200, json="2024-01-01T00:00:00Z")

        transport = httpx.MockTransport(handle_request)
        outcome = make_client(transport=transport).task_instances.get_reschedule_start_date(
            id=ti_id,
            try_number=1,
        )

        assert isinstance(outcome, TaskRescheduleStartDate)
        assert outcome.start_date == "2024-01-01T00:00:00Z"
class TestHITLOperations:
    """Exercise ``client.hitl`` against a mocked API server.

    Every handler serves only the exact ``/hitlDetails/{ti_id}`` path and answers any
    other request with a 400 "Bad Request".
    """

    def test_add_response(self) -> None:
        """A 201 reply to ``add_response`` is parsed into a ``HITLDetailRequestResult``."""
        ti_id = uuid7()

        def handle_request(request: httpx.Request) -> httpx.Response:
            # Exact match on purpose: `path in ("...")` would be a substring test, because
            # parentheses around a single string do not create a tuple.
            if request.url.path == f"/hitlDetails/{ti_id}":
                return httpx.Response(
                    status_code=201,
                    json={
                        "ti_id": str(ti_id),
                        "options": ["Approval", "Reject"],
                        "subject": "This is subject",
                        "body": "This is body",
                        "defaults": ["Approval"],
                        "params": None,
                        "multiple": False,
                    },
                )
            return httpx.Response(status_code=400, json={"detail": "Bad Request"})

        client = make_client(transport=httpx.MockTransport(handle_request))
        result = client.hitl.add_response(
            ti_id=ti_id,
            options=["Approval", "Reject"],
            subject="This is subject",
            body="This is body",
            defaults=["Approval"],
            params=None,
            multiple=False,
        )

        assert isinstance(result, HITLDetailRequestResult)
        assert result.ti_id == ti_id
        assert result.options == ["Approval", "Reject"]
        assert result.subject == "This is subject"
        assert result.body == "This is body"
        assert result.defaults == ["Approval"]
        assert result.params is None
        assert result.multiple is False
        # `assigned_users` is absent from the mocked reply.
        assert result.assigned_users is None

    def test_update_response(self, time_machine: TimeMachineFixture) -> None:
        """A 200 reply to ``update_response`` is parsed into a ``HITLDetailResponse``."""
        time_machine.move_to(datetime(2025, 7, 3, 0, 0, 0))
        ti_id = uuid7()

        def handle_request(request: httpx.Request) -> httpx.Response:
            # Exact match on purpose: `path in ("...")` would be a substring test, because
            # parentheses around a single string do not create a tuple.
            if request.url.path == f"/hitlDetails/{ti_id}":
                return httpx.Response(
                    status_code=200,
                    json={
                        "chosen_options": ["Approval"],
                        "params_input": {},
                        "responded_by_user": {"id": "admin", "name": "admin"},
                        "response_received": True,
                        "responded_at": "2025-07-03T00:00:00Z",
                    },
                )
            return httpx.Response(status_code=400, json={"detail": "Bad Request"})

        client = make_client(transport=httpx.MockTransport(handle_request))
        # NOTE(review): the request sends "Approve" while the mocked reply carries "Approval".
        # The handler never inspects the request body, so the mismatch is not checked here;
        # confirm whether it is intentional.
        result = client.hitl.update_response(
            ti_id=ti_id,
            chosen_options=["Approve"],
            params_input={},
        )

        assert isinstance(result, HITLDetailResponse)
        assert result.response_received is True
        assert result.chosen_options == ["Approval"]
        assert result.params_input == {}
        assert result.responded_by_user == HITLUser(id="admin", name="admin")
        assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0)

    def test_get_detail_response(self, time_machine: TimeMachineFixture) -> None:
        """A 200 reply to ``get_detail_response`` is parsed into a ``HITLDetailResponse``."""
        time_machine.move_to(datetime(2025, 7, 3, 0, 0, 0))
        ti_id = uuid7()

        def handle_request(request: httpx.Request) -> httpx.Response:
            # Exact match on purpose: `path in ("...")` would be a substring test, because
            # parentheses around a single string do not create a tuple.
            if request.url.path == f"/hitlDetails/{ti_id}":
                return httpx.Response(
                    status_code=200,
                    json={
                        "chosen_options": ["Approval"],
                        "params_input": {},
                        "responded_by_user": {"id": "admin", "name": "admin"},
                        "response_received": True,
                        "responded_at": "2025-07-03T00:00:00Z",
                    },
                )
            return httpx.Response(status_code=400, json={"detail": "Bad Request"})

        client = make_client(transport=httpx.MockTransport(handle_request))
        result = client.hitl.get_detail_response(ti_id=ti_id)

        assert isinstance(result, HITLDetailResponse)
        assert result.response_received is True
        assert result.chosen_options == ["Approval"]
        assert result.params_input == {}
        assert result.responded_by_user == HITLUser(id="admin", name="admin")
        assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0)