Fix typing in external task triggers (#31490)
diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 6099dc0..fc70a63 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -19,7 +19,6 @@
import asyncio
import datetime
import typing
-from typing import Any
from asgiref.sync import sync_to_async
from sqlalchemy import func
@@ -27,7 +26,7 @@
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
class TaskStateTrigger(BaseTrigger):
@@ -59,7 +58,7 @@
self.execution_dates = execution_dates
self.poll_interval = poll_interval
- def serialize(self) -> tuple[str, dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"""Serializes TaskStateTrigger arguments and classpath."""
return (
"airflow.triggers.external_task.TaskStateTrigger",
@@ -85,7 +84,7 @@
@sync_to_async
@provide_session
- def count_tasks(self, session: Session) -> int | None:
+ def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
"""Count how many task instances in the database match our criteria."""
count = (
session.query(func.count("*")) # .count() is inefficient
@@ -124,7 +123,7 @@
self.execution_dates = execution_dates
self.poll_interval = poll_interval
- def serialize(self) -> tuple[str, dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"""Serializes DagStateTrigger arguments and classpath."""
return (
"airflow.triggers.external_task.DagStateTrigger",
@@ -149,7 +148,7 @@
@sync_to_async
@provide_session
- def count_dags(self, session: Session) -> int | None:
+ def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
"""Count how many dag runs in the database match our criteria."""
count = (
session.query(func.count("*")) # .count() is inefficient