#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime

import pytest
from dateutil import relativedelta

from airflow.decorators import task
from airflow.decorators.python import _PythonDecoratedOperator
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models import MappedOperator
from airflow.models.dag import DagModel
from airflow.models.dataset import (
    DagScheduleDatasetReference,
    DatasetEvent,
    DatasetModel,
    TaskOutletDatasetReference,
)
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.dataset import DatasetEventPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE

pytestmark = pytest.mark.db_test
pytest.importorskip("pydantic", minversion="2.0.0")


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_serializing_pydantic_task_instance(session, create_task_instance):
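    """Round-trip a TaskInstance through TaskInstancePydantic JSON and check that key fields survive."""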
dag_id = "test-dag"
ti = create_task_instance(dag_id=dag_id, session=session)
ti.state = State.RUNNING
ti.next_kwargs = {"foo": "bar"}
session.commit()
pydantic_task_instance = TaskInstancePydantic.model_validate(ti)
json_string = pydantic_task_instance.model_dump_json()
deserialized_model = TaskInstancePydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id
assert deserialized_model.state == State.RUNNING
assert deserialized_model.try_number == ti.try_number
assert deserialized_model.execution_date == ti.execution_date
    assert deserialized_model.next_kwargs == {"foo": "bar"}


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, dag_maker):
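    """
    Round-trip a mapped task and its TaskInstance through serialization: the
    serialized task's `operator_class` comes back as a plain dict, and that dict
    must stay stable across `refresh_from_task` and a second round-trip.
    """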
op_class_dict_expected = {
"_needs_expansion": True,
"_task_type": "_PythonDecoratedOperator",
"downstream_task_ids": [],
"next_method": None,
"start_trigger": None,
"_operator_name": "@task",
"ui_fgcolor": "#000",
"ui_color": "#ffefeb",
"template_fields": ["templates_dict", "op_args", "op_kwargs"],
"template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"},
"template_ext": [],
"task_id": "target",
}
    with dag_maker():

        @task
        def source():
            return [1, 2, 3]

        @task
        def target(val=None):
            print(val)

        # map `target` over the values produced by `source`
        target.expand(val=source())
dr = dag_maker.create_dagrun()
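    # index 1 is the TaskInstance of the mapped `target` task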
ti = dr.task_instances[1]
# roundtrip task
ser_task = BaseSerialization.serialize(ti.task, use_pydantic_models=True)
deser_task = BaseSerialization.deserialize(ser_task, use_pydantic_models=True)
    # the original task still carries the real operator class, but the
    # serialization round-trip reduces `operator_class` to a plain dict
    assert isinstance(ti.task.operator_class, type)
    assert isinstance(deser_task.operator_class, dict)
    assert ti.task.operator_class == _PythonDecoratedOperator
ti.refresh_from_task(deser_task)
# roundtrip ti
sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)
assert "operator_class" not in sered["__var"]["task"]
assert desered.task.__class__ == MappedOperator
assert desered.task.operator_class == op_class_dict_expected
desered.refresh_from_task(deser_task)
assert desered.task.__class__ == MappedOperator
assert isinstance(desered.task.operator_class, dict)
resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
    assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_serializing_pydantic_dagrun(session, create_task_instance):
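    """Round-trip a DagRun through DagRunPydantic JSON and check that dag_id and state survive."""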
dag_id = "test-dag"
ti = create_task_instance(dag_id=dag_id, session=session)
ti.dag_run.state = State.RUNNING
session.commit()
pydantic_dag_run = DagRunPydantic.model_validate(ti.dag_run)
json_string = pydantic_dag_run.model_dump_json()
deserialized_model = DagRunPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id
    assert deserialized_model.state == State.RUNNING


@pytest.mark.parametrize(
"schedule_interval",
[
None,
"*/10 * * *",
datetime.timedelta(days=1),
relativedelta.relativedelta(days=+12),
],
)
def test_serializing_pydantic_dagmodel(schedule_interval):
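    """Round-trip a DagModel through DagModelPydantic JSON for each supported schedule_interval type."""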
dag_model = DagModel(
dag_id="test-dag",
fileloc="/tmp/dag_1.py",
schedule_interval=schedule_interval,
is_active=True,
is_paused=False,
)
pydantic_dag_model = DagModelPydantic.model_validate(dag_model)
json_string = pydantic_dag_model.model_dump_json()
deserialized_model = DagModelPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == "test-dag"
assert deserialized_model.fileloc == "/tmp/dag_1.py"
assert deserialized_model.schedule_interval == schedule_interval
assert deserialized_model.is_active is True
    assert deserialized_model.is_paused is False


def test_serializing_pydantic_local_task_job(session, create_task_instance):
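    """Round-trip a Job backed by a LocalTaskJobRunner through JobPydantic JSON."""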
dag_id = "test-dag"
ti = create_task_instance(dag_id=dag_id, session=session)
ltj = Job(dag_id=ti.dag_id)
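    # attaching the runner sets `job_type` on the Job, which is asserted below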
LocalTaskJobRunner(job=ltj, task_instance=ti)
ltj.state = State.RUNNING
session.commit()
pydantic_job = JobPydantic.model_validate(ltj)
json_string = pydantic_job.model_dump_json()
deserialized_model = JobPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id
assert deserialized_model.job_type == "LocalTaskJob"
    assert deserialized_model.state == State.RUNNING


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_serializing_pydantic_dataset_event(session, create_task_instance, create_dummy_dag):
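    """Round-trip DatasetEvent and DagRun models, including dataset references, through Pydantic JSON."""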
ds1 = DatasetModel(id=1, uri="one", extra={"foo": "bar"})
ds2 = DatasetModel(id=2, uri="two")
session.add_all([ds1, ds2])
session.commit()
# it's easier to fake a manual run here
dag, task1 = create_dummy_dag(
dag_id="test_triggering_dataset_events",
schedule=None,
start_date=DEFAULT_DATE,
task_id="test_context",
with_dagrun_type=DagRunType.MANUAL,
session=session,
)
dr = dag.create_dagrun(
run_id="test2",
run_type=DagRunType.DATASET_TRIGGERED,
execution_date=timezone.utcnow(),
state=None,
session=session,
)
ds1_event = DatasetEvent(dataset_id=1)
ds2_event_1 = DatasetEvent(dataset_id=2)
ds2_event_2 = DatasetEvent(dataset_id=2)
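    # wire ds1 as both a DAG schedule input and a task outlet, so its event
    # serializes with one consuming dag and one producing task; ds2 gets neither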
dag_ds_ref = DagScheduleDatasetReference(dag_id=dag.dag_id)
session.add(dag_ds_ref)
dag_ds_ref.dataset = ds1
task_ds_ref = TaskOutletDatasetReference(task_id=task1.task_id, dag_id=dag.dag_id)
session.add(task_ds_ref)
task_ds_ref.dataset = ds1
dr.consumed_dataset_events.append(ds1_event)
dr.consumed_dataset_events.append(ds2_event_1)
dr.consumed_dataset_events.append(ds2_event_2)
session.commit()
    pydantic_dse1 = DatasetEventPydantic.model_validate(ds1_event)
    json_string1 = pydantic_dse1.model_dump_json()
    pydantic_dse2 = DatasetEventPydantic.model_validate(ds2_event_1)
    json_string2 = pydantic_dse2.model_dump_json()
    pydantic_dag_run = DagRunPydantic.model_validate(dr)
    json_string_dr = pydantic_dag_run.model_dump_json()
deserialized_model1 = DatasetEventPydantic.model_validate_json(json_string1)
assert deserialized_model1.dataset.id == 1
assert deserialized_model1.dataset.uri == "one"
assert len(deserialized_model1.dataset.consuming_dags) == 1
assert len(deserialized_model1.dataset.producing_tasks) == 1
deserialized_model2 = DatasetEventPydantic.model_validate_json(json_string2)
assert deserialized_model2.dataset.id == 2
assert deserialized_model2.dataset.uri == "two"
assert len(deserialized_model2.dataset.consuming_dags) == 0
assert len(deserialized_model2.dataset.producing_tasks) == 0
deserialized_dr = DagRunPydantic.model_validate_json(json_string_dr)
assert len(deserialized_dr.consumed_dataset_events) == 3