blob: 1bec3ba70b897fe5347645fee7d3e0dd8adfeaca [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import os.path
import subprocess
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.resolve()))
from in_container_utils import console, get_provider_base_dir_from_path, get_provider_id_from_path
errors_found = False
def check_imports():
global errors_found
cmd = [
"ruff",
"analyze",
"graph",
"--python",
sys.executable,
]
console.print("Cmd", cmd)
import_tree_str = subprocess.check_output(cmd)
import_tree = json.loads(import_tree_str)
# Uncomment these if you want to debug strange dependencies and see if ruff gets it right
console.print("Dependencies discovered by ruff:")
console.print(import_tree)
for importing_file in sys.argv[1:]:
if not importing_file.startswith("providers/"):
console.print(f"[yellow]Skipping non-provider file: {importing_file}")
continue
importing_file_path = Path(importing_file)
console.print(importing_file_path)
imported_files_array = import_tree.get(importing_file, None)
if imported_files_array is None:
continue
imported_file_paths = [Path(file) for file in imported_files_array]
for imported_file_path in imported_file_paths:
if imported_file_path.name == "version_compat.py":
# Note - this will check also imports from other places - not only from providers
# Which means that import from tests_common, and airflow will be also banned
common_path = os.path.commonpath([importing_file, imported_file_path.as_posix()])
imported_file_parent_dir = imported_file_path.parent.as_posix()
if common_path != imported_file_parent_dir:
provider_id = get_provider_id_from_path(importing_file_path)
provider_dir = get_provider_base_dir_from_path(importing_file_path)
console.print(
f"\n[red]Invalid import of `version_compat` module in provider {provider_id} in:\n"
)
console.print(f"[yellow]{importing_file_path}")
console.print(
f"\n[bright_blue]The AIRFLOW_V_X_Y_PLUS import should be "
f"from the {provider_id} provider root directory ({provider_dir}), but it is currently from:"
)
console.print(f"\n[yellow]{imported_file_path}\n")
console.print(
f"1. Copy `version_compat`.py to `{provider_dir}/version_compat.py` if not there.\n"
f"2. Import the version constants you need as:\n\n"
f"[yellow]from airflow.providers.{provider_id}.version_compat import ...[/]\n"
f"\n"
)
errors_found = True
check_imports()
if errors_found:
console.print("\n[red]Errors found in imports![/]\n")
sys.exit(1)
else:
console.print("\n[green]All version_compat imports are correct![/]\n")