#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Check that defaults in Schema JSON match the server-side SerializedBaseOperator defaults.
This ensures that the schema accurately reflects the actual default values
used by the server-side serialization layer.
"""
from __future__ import annotations
import json
import sys
from datetime import timedelta
from pathlib import Path
from typing import Any
def load_schema_defaults(
    object_type: str = "operator", schema_path: str | Path | None = None
) -> dict[str, Any]:
    """Load declared default values for *object_type* from the serialization JSON schema.

    :param object_type: Key under the schema's ``definitions`` section
        (e.g. ``"operator"`` or ``"dag"``).
    :param schema_path: Optional override for the schema file location; when
        ``None``, the in-repo ``schema.json`` path is used (keeps the old
        call signature working unchanged).
    :return: Mapping of property name -> ``default`` value, for every property
        that explicitly declares one.
    """
    path = (
        Path(schema_path)
        if schema_path is not None
        else Path("airflow-core/src/airflow/serialization/schema.json")
    )
    if not path.exists():
        print(f"Error: Schema file not found at {path}")
        sys.exit(1)
    schema = json.loads(path.read_text())
    # Unknown object types simply yield an empty mapping rather than an error.
    properties = schema.get("definitions", {}).get(object_type, {}).get("properties", {})
    return {name: spec["default"] for name, spec in properties.items() if "default" in spec}
def get_server_side_operator_defaults() -> dict[str, Any]:
    """Collect class-level default values from SerializedBaseOperator.

    Returns a mapping of serialized field name -> JSON-compatible default,
    with sets/tuples converted to lists and timedeltas to total seconds.
    Exits the process with status 1 if the class cannot be imported or read.
    """
    try:
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        # Some fields are exposed under a public name but stored on a private
        # class attribute; map public -> internal here.
        internal_names = {"weight_rule": "_weight_rule"}
        defaults: dict[str, Any] = {}
        for public_name in SerializedBaseOperator.get_serialized_fields():
            attr = internal_names.get(public_name, public_name)
            if not hasattr(SerializedBaseOperator, attr):
                continue
            value = getattr(SerializedBaseOperator, attr)
            # Methods, properties, and nested classes are not default values.
            if callable(value) or isinstance(value, (property, type)):
                continue
            if isinstance(value, (set, tuple)):
                # schema.json is pure JSON, so collections become lists.
                value = list(value)
            elif isinstance(value, timedelta):
                value = value.total_seconds()
            defaults[public_name] = value
        return defaults
    except ImportError as e:
        print(f"Error importing SerializedBaseOperator: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error getting server-side defaults: {e}")
        sys.exit(1)
def get_server_side_dag_defaults() -> dict[str, Any]:
    """Collect default values from a freshly constructed SerializedDAG.

    DAG defaults are assigned in ``__init__``, so a throwaway instance is
    inspected instead of the class itself. Exits the process with status 1
    if the class cannot be imported or instantiated.
    """
    try:
        from airflow.serialization.serialized_objects import SerializedDAG

        probe = SerializedDAG(dag_id="temp")
        defaults: dict[str, Any] = {}
        for name in SerializedDAG.get_serialized_fields():
            if not hasattr(probe, name):
                continue
            value = getattr(probe, name)
            # Methods, properties, and nested classes are not default values.
            if callable(value) or isinstance(value, (property, type)):
                continue
            if isinstance(value, (set, tuple)):
                # Convert to list since schema.json is pure JSON.
                value = list(value)
            defaults[name] = value
        return defaults
    except ImportError as e:
        print(f"Error importing SerializedDAG: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error getting server-side DAG defaults: {e}")
        sys.exit(1)
def compare_operator_defaults() -> list[str]:
    """Cross-check operator defaults between schema.json and SerializedBaseOperator.

    :return: Human-readable discrepancy messages; empty when everything matches.
    """
    schema_defaults = load_schema_defaults("operator")
    server_defaults = get_server_side_operator_defaults()
    errors: list[str] = []
    print(f"Found {len(schema_defaults)} operator schema defaults")
    print(f"Found {len(server_defaults)} operator server-side defaults")

    # Direction 1: every meaningful server-side default must appear in the
    # schema with an identical value.
    for field_name, server_value in server_defaults.items():
        if field_name not in schema_defaults:
            # None and empty collections don't require an explicit schema default.
            if server_value is not None and server_value not in [[], {}, (), set()]:
                errors.append(
                    f"Server field '{field_name}' has default {server_value!r} but no schema default"
                )
        elif schema_defaults[field_name] != server_value:
            errors.append(
                f"Field '{field_name}': schema default is {schema_defaults[field_name]!r}, "
                f"server default is {server_value!r}"
            )

    # Direction 2: schema defaults without a server-side counterpart are
    # suspect, except for purely structural fields.
    schema_only_fields = {
        "task_type",
        "_task_module",
        "task_id",
        "_task_display_name",
        "_is_mapped",
        "_is_sensor",
    }
    for field_name, schema_value in schema_defaults.items():
        if field_name in server_defaults or field_name in schema_only_fields:
            continue
        errors.append(
            f"Schema has default for '{field_name}' = {schema_value!r} but no corresponding server default"
        )
    return errors
def compare_dag_defaults() -> list[str]:
    """Cross-check DAG defaults between schema.json and SerializedDAG.

    :return: Human-readable discrepancy messages; empty when everything matches.
    """
    schema_defaults = load_schema_defaults("dag")
    server_defaults = get_server_side_dag_defaults()
    errors: list[str] = []
    print(f"Found {len(schema_defaults)} DAG schema defaults")
    print(f"Found {len(server_defaults)} DAG server-side defaults")

    # Direction 1: every meaningful server-side default must appear in the
    # schema with an identical value.
    for field_name, server_value in server_defaults.items():
        if field_name not in schema_defaults:
            # None values, empty collections, and identity fields don't
            # require an explicit schema default.
            if (
                server_value is not None
                and server_value not in [[], {}, (), set()]
                and field_name not in ["dag_id", "dag_display_name"]
            ):
                errors.append(
                    f"DAG server field '{field_name}' has default {server_value!r} but no schema default"
                )
        elif schema_defaults[field_name] != server_value:
            errors.append(
                f"DAG field '{field_name}': schema default is {schema_defaults[field_name]!r}, "
                f"server default is {server_value!r}"
            )

    # Direction 2: schema defaults without a server-side counterpart are
    # suspect, except for computed has_on_*_callback properties.
    computed_properties = {
        "has_on_success_callback",
        "has_on_failure_callback",
    }
    for field_name, schema_value in schema_defaults.items():
        if field_name in server_defaults or field_name in computed_properties:
            continue
        errors.append(
            f"DAG schema has default for '{field_name}' = {schema_value!r} but no corresponding server default"
        )
    return errors
def main():
    """Run both checks and exit non-zero when any discrepancy is found."""
    print("Checking schema defaults against server-side serialization classes...")
    print("\n1. Checking Operator defaults...")
    problems = compare_operator_defaults()
    print("\n2. Checking Dag defaults...")
    problems += compare_dag_defaults()
    if not problems:
        print("\n✅ All schema defaults match server-side defaults!")
        return
    print("\n❌ Found discrepancies between schema and server defaults:")
    for problem in problems:
        print(f" • {problem}")
    print()
    print("To fix these issues:")
    print("1. Update airflow-core/src/airflow/serialization/schema.json to match server defaults, OR")
    print(
        "2. Update airflow-core/src/airflow/serialization/serialized_objects.py class/init defaults to match schema"
    )
    sys.exit(1)


if __name__ == "__main__":
    main()