| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import json |
| import os |
| import sys |
| from enum import Enum |
| from functools import cached_property, lru_cache |
| from re import match |
| from typing import Any, Dict, List, TypeVar |
| |
# ``Literal`` is importable from ``typing`` on all supported Python versions
# (3.8+), so no version-conditional import is needed — both branches of the
# previous ``sys.version_info`` check were identical anyway.
from typing import Literal
| |
| from airflow_breeze.global_constants import ( |
| ALL_PYTHON_MAJOR_MINOR_VERSIONS, |
| APACHE_AIRFLOW_GITHUB_REPOSITORY, |
| COMMITTERS, |
| CURRENT_KUBERNETES_VERSIONS, |
| CURRENT_MSSQL_VERSIONS, |
| CURRENT_MYSQL_VERSIONS, |
| CURRENT_POSTGRES_VERSIONS, |
| CURRENT_PYTHON_MAJOR_MINOR_VERSIONS, |
| DEFAULT_KUBERNETES_VERSION, |
| DEFAULT_MSSQL_VERSION, |
| DEFAULT_MYSQL_VERSION, |
| DEFAULT_POSTGRES_VERSION, |
| DEFAULT_PYTHON_MAJOR_MINOR_VERSION, |
| HELM_VERSION, |
| KIND_VERSION, |
| RUNS_ON_PUBLIC_RUNNER, |
| RUNS_ON_SELF_HOSTED_RUNNER, |
| SELF_HOSTED_RUNNERS_CPU_COUNT, |
| GithubEvents, |
| SelectiveUnitTestTypes, |
| all_helm_test_packages, |
| all_selective_test_types, |
| ) |
| from airflow_breeze.utils.console import get_console |
| from airflow_breeze.utils.exclude_from_matrix import excluded_combos |
| from airflow_breeze.utils.github_actions import get_ga_output |
| from airflow_breeze.utils.kubernetes_utils import get_kubernetes_python_combos |
| from airflow_breeze.utils.path_utils import ( |
| AIRFLOW_PROVIDERS_ROOT, |
| AIRFLOW_SOURCES_ROOT, |
| DOCS_DIR, |
| SYSTEM_TESTS_PROVIDERS_ROOT, |
| TESTS_PROVIDERS_ROOT, |
| ) |
| from airflow_breeze.utils.provider_dependencies import DEPENDENCIES, get_related_providers |
| |
# PR labels that alter how the CI run is configured (matched against the PR's label list).
FULL_TESTS_NEEDED_LABEL = "full tests needed"
DEBUG_CI_RESOURCES_LABEL = "debug ci resources"
USE_PUBLIC_RUNNERS_LABEL = "use public runners"
| |
| |
class FileGroupForCi(Enum):
    """Named groups of files; changes in a group enable the corresponding CI checks."""

    ENVIRONMENT_FILES = "environment_files"
    PYTHON_PRODUCTION_FILES = "python_scans"
    JAVASCRIPT_PRODUCTION_FILES = "javascript_scans"
    API_TEST_FILES = "api_test_files"
    API_CODEGEN_FILES = "api_codegen_files"
    HELM_FILES = "helm_files"
    SETUP_FILES = "setup_files"
    DOC_FILES = "doc_files"
    WWW_FILES = "www_files"
    SYSTEM_TEST_FILES = "system_tests"
    KUBERNETES_FILES = "kubernetes_files"
    ALL_PYTHON_FILES = "all_python_files"
    ALL_SOURCE_FILES = "all_sources_for_tests"
| |
| |
# Key type of the match dictionaries below: a CI file group or a selective test type.
T = TypeVar("T", FileGroupForCi, SelectiveUnitTestTypes)
| |
| |
class HashableDict(Dict[T, List[str]]):
    """Dict that hashes by its key set, so it can be passed to ``lru_cache``-d methods."""

    def __hash__(self):
        # Hashing the keys only is sufficient: equal dicts always share a key set.
        return hash(frozenset(self.keys()))
| |
| |
# Maps each CI file group to the regexps (applied with ``re.match`` to repo-relative
# paths of changed files) that decide whether a change belongs to the group.
CI_FILE_GROUP_MATCHES = HashableDict(
    {
        FileGroupForCi.ENVIRONMENT_FILES: [
            r"^.github/workflows",
            r"^dev/breeze",
            r"^dev/.*\.py$",
            r"^Dockerfile",
            r"^scripts",
            r"^setup.py",
            r"^setup.cfg",
            r"^generated/provider_dependencies.json$",
        ],
        FileGroupForCi.PYTHON_PRODUCTION_FILES: [
            r"^airflow/.*\.py",
            r"^setup.py",
        ],
        FileGroupForCi.JAVASCRIPT_PRODUCTION_FILES: [
            r"^airflow/.*\.[jt]sx?",
            r"^airflow/.*\.lock",
        ],
        FileGroupForCi.API_TEST_FILES: [
            r"^airflow/api/",
            r"^airflow/api_connexion/",
        ],
        FileGroupForCi.API_CODEGEN_FILES: [
            r"^airflow/api_connexion/openapi/v1\.yaml",
            r"^clients/gen",
        ],
        FileGroupForCi.HELM_FILES: [
            r"^chart",
            r"^airflow/kubernetes",
            r"^tests/kubernetes",
            r"^tests/charts",
        ],
        FileGroupForCi.SETUP_FILES: [
            r"^pyproject.toml",
            r"^setup.cfg",
            r"^setup.py",
            r"^generated/provider_dependencies.json$",
            r"^airflow/providers/.*/provider.yaml$",
        ],
        FileGroupForCi.DOC_FILES: [
            r"^docs",
            r"^\.github/SECURITY\.rst$",
            r"^airflow/.*\.py$",
            r"^chart",
            r"^providers",
            r"^tests/system",
            r"^CHANGELOG\.txt",
            r"^airflow/config_templates/config\.yml",
            r"^chart/RELEASE_NOTES\.txt",
            r"^chart/values\.schema\.json",
            r"^chart/values\.json",
        ],
        FileGroupForCi.WWW_FILES: [
            r"^airflow/www/.*\.ts[x]?$",
            r"^airflow/www/.*\.js[x]?$",
            r"^airflow/www/[^/]+\.json$",
            r"^airflow/www/.*\.lock$",
        ],
        FileGroupForCi.KUBERNETES_FILES: [
            r"^chart",
            r"^kubernetes_tests",
            r"^airflow/providers/cncf/kubernetes/",
            r"^tests/providers/cncf/kubernetes/",
            r"^tests/system/providers/cncf/kubernetes/",
        ],
        FileGroupForCi.ALL_PYTHON_FILES: [
            r"\.py$",
        ],
        FileGroupForCi.ALL_SOURCE_FILES: [
            r"^.pre-commit-config.yaml$",
            r"^airflow",
            r"^chart",
            r"^tests",
            r"^kubernetes_tests",
        ],
        FileGroupForCi.SYSTEM_TEST_FILES: [
            r"^tests/system/",
        ],
    }
)
| |
| |
# Maps each selective unit test type to the regexps (applied with ``re.match``)
# that decide whether a changed file makes that test type run.
TEST_TYPE_MATCHES = HashableDict(
    {
        SelectiveUnitTestTypes.API: [
            r"^airflow/api",
            r"^airflow/api_connexion",
            r"^tests/api",
            r"^tests/api_connexion",
        ],
        SelectiveUnitTestTypes.CLI: [
            r"^airflow/cli",
            r"^tests/cli",
        ],
        SelectiveUnitTestTypes.PROVIDERS: [
            r"^airflow/providers/",
            r"^tests/system/providers/",
            r"^tests/providers/",
        ],
        SelectiveUnitTestTypes.WWW: [r"^airflow/www", r"^tests/www"],
    }
)
| |
| |
def find_provider_affected(changed_file: str, include_docs: bool) -> str | None:
    """Return the provider id (e.g. "cncf.kubernetes") affected by the changed file.

    Returns "Providers" when the file lives in a providers tree but cannot be
    attributed to a single provider, and None when no provider is affected.
    """
    file_path = AIRFLOW_SOURCES_ROOT / changed_file
    # Path.is_relative_to is only available in Python 3.9+ so we probe with relative_to().
    relative_base_path = None
    for provider_root in (TESTS_PROVIDERS_ROOT, SYSTEM_TESTS_PROVIDERS_ROOT, AIRFLOW_PROVIDERS_ROOT):
        try:
            file_path.relative_to(provider_root)
        except ValueError:
            continue
        relative_base_path = provider_root
        break
    if relative_base_path is None:
        # Not under any providers tree - it might still be a provider documentation file.
        if include_docs:
            try:
                docs_relative_path = file_path.relative_to(DOCS_DIR)
            except ValueError:
                return None
            top_level_dir = docs_relative_path.parts[0]
            if top_level_dir.startswith("apache-airflow-providers-"):
                return top_level_dir.replace("apache-airflow-providers-", "").replace("-", ".")
        return None

    for parent_dir_path in file_path.parents:
        if parent_dir_path == relative_base_path:
            break
        relative_path = parent_dir_path.relative_to(relative_base_path)
        # The nearest ancestor directory containing provider.yaml identifies the provider.
        if (AIRFLOW_PROVIDERS_ROOT / relative_path / "provider.yaml").exists():
            return str(relative_path).replace(os.sep, ".")
    # If we got here it means that some "common" files were modified. so we need to test all Providers
    return "Providers"
| |
| |
def find_all_providers_affected(
    changed_files: tuple[str, ...], include_docs: bool, fail_if_suspended_providers_affected: bool
) -> list[str] | Literal["ALL_PROVIDERS"] | None:
    """Map the changed files to the set of providers that need to be tested.

    Returns "ALL_PROVIDERS" when common provider code changed, None when no
    provider is affected, otherwise the sorted list of affected provider ids
    extended with their related (upstream and downstream) providers.
    """
    active_providers: set[str] = set()
    suspended_providers: set[str] = set()
    all_providers_affected = False
    for changed_file in changed_files:
        provider = find_provider_affected(changed_file, include_docs=include_docs)
        if provider is None:
            continue
        if provider == "Providers":
            all_providers_affected = True
        elif provider in DEPENDENCIES:
            active_providers.add(provider)
        else:
            # Providers missing from the generated dependency map are suspended.
            suspended_providers.add(provider)
    if all_providers_affected:
        return "ALL_PROVIDERS"
    if suspended_providers:
        # We check for suspended providers only after we have checked if all providers are affected.
        # No matter if we found that we are modifying a suspended provider individually, if all providers are
        # affected, then it means that we are ok to proceed because likely we are running some kind of
        # global refactoring that affects multiple providers including the suspended one. This is a
        # potential escape hatch if someone would like to modify suspended provider,
        # but it can be found at the review time and is anyway harmless as the provider will not be
        # released nor tested nor used in CI anyway.
        get_console().print("[yellow]You are modifying suspended providers.\n")
        get_console().print(
            "[info]Some providers modified by this change have been suspended, "
            "and before attempting such changes you should fix the reason for suspension."
        )
        get_console().print(
            "[info]When fixing it, you should set suspended = false in provider.yaml "
            "to make changes to the provider."
        )
        get_console().print(f"Suspended providers: {suspended_providers}")
        if fail_if_suspended_providers_affected:
            get_console().print(
                "[error]This PR did not have `allow suspended provider changes` label set so it will fail."
            )
            sys.exit(1)
        else:
            get_console().print(
                "[info]This PR had `allow suspended provider changes` label set so it will continue"
            )
    if not active_providers:
        return None
    # Pull in providers related through cross-provider dependencies, both directions.
    for provider in list(active_providers):
        active_providers.update(
            get_related_providers(provider, upstream_dependencies=True, downstream_dependencies=True)
        )
    return sorted(active_providers)
| |
| |
class SelectiveChecks:
    """Computes which CI checks, test types and matrices should run for a change set."""

    # Fields that define instance identity for __eq__/__hash__ (used by lru_cache).
    __HASHABLE_FIELDS = {"_files", "_default_branch", "_commit_ref", "_pr_labels", "_github_event"}
| |
    def __init__(
        self,
        files: tuple[str, ...] = (),
        default_branch: str = "main",
        default_constraints_branch: str = "constraints-main",
        commit_ref: str | None = None,
        pr_labels: tuple[str, ...] = (),
        github_event: GithubEvents = GithubEvents.PULL_REQUEST,
        github_repository: str = APACHE_AIRFLOW_GITHUB_REPOSITORY,
        github_actor: str = "",
        github_context_dict: dict[str, Any] | None = None,
    ):
        """Prepare selective-check computation for the given change set.

        :param files: repo-relative paths of the changed files
        :param default_branch: target branch of the change (e.g. "main")
        :param default_constraints_branch: constraints branch matching the target branch
        :param commit_ref: commit SHA checks run against (missing ref runs everything)
        :param pr_labels: labels set on the PR - some labels alter check selection
        :param github_event: the GitHub event that triggered the build
        :param github_repository: repository the build runs in
        :param github_actor: GitHub login of the actor that triggered the build
        :param github_context_dict: parsed GITHUB_CONTEXT, used to find the PR author
        """
        self._files = files
        self._default_branch = default_branch
        self._default_constraints_branch = default_constraints_branch
        self._commit_ref = commit_ref
        self._pr_labels = pr_labels
        self._github_event = github_event
        self._github_repository = github_repository
        self._github_actor = github_actor
        self._github_context_dict = github_context_dict or {}
| |
| def __important_attributes(self) -> tuple[Any, ...]: |
| return tuple(getattr(self, f) for f in self.__HASHABLE_FIELDS) |
| |
| def __hash__(self): |
| return hash(self.__important_attributes()) |
| |
| def __eq__(self, other): |
| return isinstance(other, SelectiveChecks) and all( |
| [getattr(other, f) == getattr(self, f) for f in self.__HASHABLE_FIELDS] |
| ) |
| |
| def __str__(self) -> str: |
| output = [] |
| for field_name in dir(self): |
| if not field_name.startswith("_"): |
| value = getattr(self, field_name) |
| if value is not None: |
| output.append(get_ga_output(field_name, value)) |
| return "\n".join(output) |
| |
    # Default (single-combination) versions used when the full matrix is not needed.
    default_python_version = DEFAULT_PYTHON_MAJOR_MINOR_VERSION
    default_postgres_version = DEFAULT_POSTGRES_VERSION
    default_mysql_version = DEFAULT_MYSQL_VERSION
    default_mssql_version = DEFAULT_MSSQL_VERSION

    # Tooling versions used for Kubernetes tests.
    default_kubernetes_version = DEFAULT_KUBERNETES_VERSION
    default_kind_version = KIND_VERSION
    default_helm_version = HELM_VERSION
| |
    @cached_property
    def default_branch(self) -> str:
        """Target branch of the change (e.g. "main")."""
        return self._default_branch

    @cached_property
    def default_constraints_branch(self) -> str:
        """Constraints branch matching the target branch."""
        return self._default_constraints_branch
| |
| @cached_property |
| def full_tests_needed(self) -> bool: |
| if not self._commit_ref: |
| get_console().print("[warning]Running everything as commit is missing[/]") |
| return True |
| if self._github_event in [GithubEvents.PUSH, GithubEvents.SCHEDULE, GithubEvents.WORKFLOW_DISPATCH]: |
| get_console().print(f"[warning]Full tests needed because event is {self._github_event}[/]") |
| return True |
| if len(self._matching_files(FileGroupForCi.ENVIRONMENT_FILES, CI_FILE_GROUP_MATCHES)) > 0: |
| get_console().print("[warning]Running everything because env files changed[/]") |
| return True |
| if FULL_TESTS_NEEDED_LABEL in self._pr_labels: |
| get_console().print( |
| "[warning]Full tests needed because " |
| f"label '{FULL_TESTS_NEEDED_LABEL}' is in {self._pr_labels}[/]" |
| ) |
| return True |
| return False |
| |
| @cached_property |
| def python_versions(self) -> list[str]: |
| return ( |
| CURRENT_PYTHON_MAJOR_MINOR_VERSIONS |
| if self.full_tests_needed |
| else [DEFAULT_PYTHON_MAJOR_MINOR_VERSION] |
| ) |
| |
| @cached_property |
| def python_versions_list_as_string(self) -> str: |
| return " ".join(self.python_versions) |
| |
| @cached_property |
| def all_python_versions(self) -> list[str]: |
| return ( |
| ALL_PYTHON_MAJOR_MINOR_VERSIONS |
| if self.full_tests_needed |
| else [DEFAULT_PYTHON_MAJOR_MINOR_VERSION] |
| ) |
| |
| @cached_property |
| def all_python_versions_list_as_string(self) -> str: |
| return " ".join(self.all_python_versions) |
| |
| @cached_property |
| def postgres_versions(self) -> list[str]: |
| return CURRENT_POSTGRES_VERSIONS if self.full_tests_needed else [DEFAULT_POSTGRES_VERSION] |
| |
| @cached_property |
| def mysql_versions(self) -> list[str]: |
| return CURRENT_MYSQL_VERSIONS if self.full_tests_needed else [DEFAULT_MYSQL_VERSION] |
| |
| @cached_property |
| def mssql_versions(self) -> list[str]: |
| return CURRENT_MSSQL_VERSIONS if self.full_tests_needed else [DEFAULT_MSSQL_VERSION] |
| |
| @cached_property |
| def kind_version(self) -> str: |
| return KIND_VERSION |
| |
| @cached_property |
| def helm_version(self) -> str: |
| return HELM_VERSION |
| |
| @cached_property |
| def postgres_exclude(self) -> list[dict[str, str]]: |
| if not self.full_tests_needed: |
| # Only basic combination so we do not need to exclude anything |
| return [] |
| return [ |
| # Exclude all combinations that are repeating python/postgres versions |
| {"python-version": python_version, "postgres-version": postgres_version} |
| for python_version, postgres_version in excluded_combos( |
| CURRENT_PYTHON_MAJOR_MINOR_VERSIONS, CURRENT_POSTGRES_VERSIONS |
| ) |
| ] |
| |
| @cached_property |
| def mssql_exclude(self) -> list[dict[str, str]]: |
| if not self.full_tests_needed: |
| # Only basic combination so we do not need to exclude anything |
| return [] |
| return [ |
| # Exclude all combinations that are repeating python/mssql versions |
| {"python-version": python_version, "mssql-version": mssql_version} |
| for python_version, mssql_version in excluded_combos( |
| CURRENT_PYTHON_MAJOR_MINOR_VERSIONS, CURRENT_MSSQL_VERSIONS |
| ) |
| ] |
| |
| @cached_property |
| def mysql_exclude(self) -> list[dict[str, str]]: |
| if not self.full_tests_needed: |
| # Only basic combination so we do not need to exclude anything |
| return [] |
| return [ |
| # Exclude all combinations that are repeating python/mysql versions |
| {"python-version": python_version, "mysql-version": mysql_version} |
| for python_version, mysql_version in excluded_combos( |
| CURRENT_PYTHON_MAJOR_MINOR_VERSIONS, CURRENT_MYSQL_VERSIONS |
| ) |
| ] |
| |
| @cached_property |
| def sqlite_exclude(self) -> list[dict[str, str]]: |
| return [] |
| |
| @cached_property |
| def kubernetes_versions(self) -> list[str]: |
| return CURRENT_KUBERNETES_VERSIONS if self.full_tests_needed else [DEFAULT_KUBERNETES_VERSION] |
| |
| @cached_property |
| def kubernetes_versions_list_as_string(self) -> str: |
| return " ".join(self.kubernetes_versions) |
| |
| @cached_property |
| def kubernetes_combos_list_as_string(self) -> str: |
| python_version_array: list[str] = self.python_versions_list_as_string.split(" ") |
| kubernetes_version_array: list[str] = self.kubernetes_versions_list_as_string.split(" ") |
| combo_titles, short_combo_titles, combos = get_kubernetes_python_combos( |
| kubernetes_version_array, python_version_array |
| ) |
| return " ".join(short_combo_titles) |
| |
| def _match_files_with_regexps(self, matched_files, regexps): |
| for file in self._files: |
| for regexp in regexps: |
| if match(regexp, file): |
| matched_files.append(file) |
| break |
| |
    @lru_cache(maxsize=None)
    def _matching_files(self, match_group: T, match_dict: dict[T, list[str]]) -> list[str]:
        """Return the changed files belonging to ``match_group`` (and log the result).

        Caching works because the class defines __hash__/__eq__ over its important
        fields and the match dicts are HashableDict instances.
        NOTE(review): lru_cache on a method keeps instances alive for the cache's
        lifetime - presumably acceptable here as this looks like a short-lived CLI
        computation; confirm if instances are ever long-lived.
        """
        matched_files: list[str] = []
        regexps = match_dict[match_group]
        self._match_files_with_regexps(matched_files, regexps)
        count = len(matched_files)
        if count > 0:
            get_console().print(f"[warning]{match_group} matched {count} files.[/]")
            get_console().print(matched_files)
        else:
            get_console().print(f"[warning]{match_group} did not match any file.[/]")
        return matched_files
| |
| def _should_be_run(self, source_area: FileGroupForCi) -> bool: |
| if self.full_tests_needed: |
| get_console().print(f"[warning]{source_area} enabled because we are running everything[/]") |
| return True |
| matched_files = self._matching_files(source_area, CI_FILE_GROUP_MATCHES) |
| if len(matched_files) > 0: |
| get_console().print( |
| f"[warning]{source_area} enabled because it matched {len(matched_files)} changed files[/]" |
| ) |
| return True |
| else: |
| get_console().print( |
| f"[warning]{source_area} disabled because it did not match any changed files[/]" |
| ) |
| return False |
| |
    @cached_property
    def needs_python_scans(self) -> bool:
        """Whether Python production-code scans should run."""
        return self._should_be_run(FileGroupForCi.PYTHON_PRODUCTION_FILES)

    @cached_property
    def needs_javascript_scans(self) -> bool:
        """Whether JavaScript production-code scans should run."""
        return self._should_be_run(FileGroupForCi.JAVASCRIPT_PRODUCTION_FILES)

    @cached_property
    def needs_api_tests(self) -> bool:
        """Whether API tests should run."""
        return self._should_be_run(FileGroupForCi.API_TEST_FILES)

    @cached_property
    def needs_api_codegen(self) -> bool:
        """Whether API client code-generation checks should run."""
        return self._should_be_run(FileGroupForCi.API_CODEGEN_FILES)

    @cached_property
    def run_www_tests(self) -> bool:
        """Whether webserver (www) tests should run."""
        return self._should_be_run(FileGroupForCi.WWW_FILES)
| |
| @cached_property |
| def run_amazon_tests(self) -> bool: |
| if self.parallel_test_types_list_as_string is None: |
| return False |
| return ( |
| "amazon" in self.parallel_test_types_list_as_string |
| or "Providers" in self.parallel_test_types_list_as_string.split(" ") |
| ) |
| |
    @cached_property
    def run_kubernetes_tests(self) -> bool:
        """Whether Kubernetes tests should run."""
        return self._should_be_run(FileGroupForCi.KUBERNETES_FILES)

    @cached_property
    def docs_build(self) -> bool:
        """Whether the documentation build should run."""
        return self._should_be_run(FileGroupForCi.DOC_FILES)

    @cached_property
    def needs_helm_tests(self) -> bool:
        """Whether Helm chart tests should run (only on the main branch)."""
        return self._should_be_run(FileGroupForCi.HELM_FILES) and self._default_branch == "main"

    @cached_property
    def run_tests(self) -> bool:
        """Whether unit tests should run at all."""
        return self._should_be_run(FileGroupForCi.ALL_SOURCE_FILES)

    @cached_property
    def image_build(self) -> bool:
        """Whether CI images need to be built (needed by tests, docs and K8S tests)."""
        return self.run_tests or self.docs_build or self.run_kubernetes_tests
| |
| def _select_test_type_if_matching( |
| self, test_types: set[str], test_type: SelectiveUnitTestTypes |
| ) -> list[str]: |
| matched_files = self._matching_files(test_type, TEST_TYPE_MATCHES) |
| count = len(matched_files) |
| if count > 0: |
| test_types.add(test_type.value) |
| get_console().print(f"[warning]{test_type} added because it matched {count} files[/]") |
| return matched_files |
| |
    def _are_all_providers_affected(self) -> bool:
        # if "Providers" test is present in the list of tests, it means that we should run all providers tests
        # prepare all providers packages and build all providers documentation
        return "Providers" in self._get_test_types_to_run()

    def _fail_if_suspended_providers_affected(self) -> bool:
        # Suspended-provider changes fail the build unless the PR carries the override label.
        return "allow suspended provider changes" not in self._pr_labels
| |
    def _get_test_types_to_run(self) -> list[str]:
        """Compute the sorted list of test types to run for the changed files.

        "Always" is always included. When any changed source file does not fall
        into a known category (Core/Other fallback) all test types are selected.
        """
        if self.full_tests_needed:
            return list(all_selective_test_types())

        candidate_test_types: set[str] = {"Always"}
        matched_files: set[str] = set()
        matched_files.update(
            self._select_test_type_if_matching(candidate_test_types, SelectiveUnitTestTypes.WWW)
        )
        matched_files.update(
            self._select_test_type_if_matching(candidate_test_types, SelectiveUnitTestTypes.PROVIDERS)
        )
        matched_files.update(
            self._select_test_type_if_matching(candidate_test_types, SelectiveUnitTestTypes.CLI)
        )
        matched_files.update(
            self._select_test_type_if_matching(candidate_test_types, SelectiveUnitTestTypes.API)
        )

        kubernetes_files = self._matching_files(FileGroupForCi.KUBERNETES_FILES, CI_FILE_GROUP_MATCHES)
        system_test_files = self._matching_files(FileGroupForCi.SYSTEM_TEST_FILES, CI_FILE_GROUP_MATCHES)
        all_source_files = self._matching_files(FileGroupForCi.ALL_SOURCE_FILES, CI_FILE_GROUP_MATCHES)

        # Source files not attributed to any specific test type, kubernetes files or
        # system tests are treated as Core/Other changes - then everything must run.
        remaining_files = (
            set(all_source_files) - set(matched_files) - set(kubernetes_files) - set(system_test_files)
        )
        count_remaining_files = len(remaining_files)
        if count_remaining_files > 0:
            get_console().print(
                f"[warning]We should run all tests. There are {count_remaining_files} changed "
                "files that seems to fall into Core/Other category[/]"
            )
            get_console().print(remaining_files)
            candidate_test_types.update(all_selective_test_types())
        else:
            if "Providers" in candidate_test_types:
                # Narrow the generic "Providers" entry down to only the affected providers.
                affected_providers = find_all_providers_affected(
                    changed_files=self._files,
                    include_docs=False,
                    fail_if_suspended_providers_affected=self._fail_if_suspended_providers_affected(),
                )
                if affected_providers != "ALL_PROVIDERS" and affected_providers is not None:
                    candidate_test_types.remove("Providers")
                    candidate_test_types.add(f"Providers[{','.join(sorted(affected_providers))}]")
            get_console().print(
                "[warning]There are no core/other files. Only tests relevant to the changed files are run.[/]"
            )
        sorted_candidate_test_types = list(sorted(candidate_test_types))
        get_console().print("[warning]Selected test type candidates to run:[/]")
        get_console().print(sorted_candidate_test_types)
        return sorted_candidate_test_types
| |
| @staticmethod |
| def _extract_long_provider_tests(current_test_types: set[str]): |
| """ |
| In case there are Provider tests in the list of test to run - either in the form of |
| Providers or Providers[...] we subtract them from the test type, |
| and add them to the list of tests to run individually. |
| |
| In case of Providers, we need to replace it with Providers[-<list_of_long_tests>], but |
| in case of Providers[list_of_tests] we need to remove the long tests from the list. |
| |
| """ |
| long_tests = ["amazon", "google"] |
| for original_test_type in tuple(current_test_types): |
| if original_test_type == "Providers": |
| current_test_types.remove(original_test_type) |
| for long_test in long_tests: |
| current_test_types.add(f"Providers[{long_test}]") |
| current_test_types.add(f"Providers[-{','.join(long_tests)}]") |
| elif original_test_type.startswith("Providers["): |
| provider_tests_to_run = ( |
| original_test_type.replace("Providers[", "").replace("]", "").split(",") |
| ) |
| if any(long_test in provider_tests_to_run for long_test in long_tests): |
| current_test_types.remove(original_test_type) |
| for long_test in long_tests: |
| if long_test in provider_tests_to_run: |
| current_test_types.add(f"Providers[{long_test}]") |
| provider_tests_to_run.remove(long_test) |
| current_test_types.add(f"Providers[{','.join(provider_tests_to_run)}]") |
| |
    @cached_property
    def parallel_test_types_list_as_string(self) -> str | None:
        """Space-separated, specially-ordered test types to run in parallel (None = no tests).

        Provider test types are removed on non-main branches, and long provider
        tests (amazon, google) are split out into separate entries.
        """
        if not self.run_tests:
            return None
        current_test_types = set(self._get_test_types_to_run())
        if self._default_branch != "main":
            test_types_to_remove: set[str] = set()
            for test_type in current_test_types:
                if test_type.startswith("Providers"):
                    get_console().print(
                        f"[warning]Removing {test_type} because the target branch "
                        f"is {self._default_branch} and not main[/]"
                    )
                    test_types_to_remove.add(test_type)
            current_test_types = current_test_types - test_types_to_remove

        self._extract_long_provider_tests(current_test_types)

        # this should be hard-coded as we want to have very specific sequence of tests
        sorting_order = ["Core", "Providers[-amazon,google]", "Other", "Providers[amazon]", "WWW"]

        def sort_key(t: str) -> str:
            # Put the test types in the order we want them to run
            if t in sorting_order:
                return str(sorting_order.index(t))
            else:
                return str(len(sorting_order)) + t

        return " ".join(
            sorted(
                current_test_types,
                key=sort_key,
            )
        )
| |
| @cached_property |
| def basic_checks_only(self) -> bool: |
| return not self.image_build |
| |
| @cached_property |
| def upgrade_to_newer_dependencies(self) -> bool: |
| return len( |
| self._matching_files(FileGroupForCi.SETUP_FILES, CI_FILE_GROUP_MATCHES) |
| ) > 0 or self._github_event in [GithubEvents.PUSH, GithubEvents.SCHEDULE] |
| |
    @cached_property
    def docs_filter_list_as_string(self) -> str | None:
        """--package-filter arguments for the docs build (None = no docs build).

        Empty string means "build all docs".
        """
        _ALL_DOCS_LIST = ""
        if not self.docs_build:
            return None
        if self._default_branch != "main":
            # Non-main branches only build the core airflow and docker-stack docs.
            return "--package-filter apache-airflow --package-filter docker-stack"
        if self.full_tests_needed:
            return _ALL_DOCS_LIST
        providers_affected = find_all_providers_affected(
            changed_files=self._files,
            include_docs=True,
            fail_if_suspended_providers_affected=self._fail_if_suspended_providers_affected(),
        )
        if (
            providers_affected == "ALL_PROVIDERS"
            or "docs/conf.py" in self._files
            or "docs/build_docs.py" in self._files
            or self._are_all_providers_affected()
        ):
            return _ALL_DOCS_LIST
        packages = []
        if any([file.startswith("airflow/") for file in self._files]):
            packages.append("apache-airflow")
        if any([file.startswith("chart/") or file.startswith("docs/helm-chart") for file in self._files]):
            packages.append("helm-chart")
        if any([file.startswith("docs/docker-stack/") for file in self._files]):
            packages.append("docker-stack")
        if providers_affected:
            for provider in providers_affected:
                packages.append(f"apache-airflow-providers-{provider.replace('.', '-')}")
        return " ".join([f"--package-filter {package}" for package in packages])
| |
| @cached_property |
| def skip_pre_commits(self) -> str: |
| return ( |
| "identity" |
| if self._default_branch == "main" |
| else "identity,check-airflow-provider-compatibility," |
| "check-extra-packages-references,check-provider-yaml-valid" |
| ) |
| |
| @cached_property |
| def skip_provider_tests(self) -> bool: |
| if self._default_branch != "main": |
| return True |
| if self.full_tests_needed: |
| return False |
| if any(test_type.startswith("Providers") for test_type in self._get_test_types_to_run()): |
| return False |
| return True |
| |
    @cached_property
    def cache_directive(self) -> str:
        """Docker cache strategy: scheduled builds run without cache, others use the registry."""
        return "disabled" if self._github_event == GithubEvents.SCHEDULE else "registry"

    @cached_property
    def debug_resources(self) -> bool:
        """Whether CI resource debugging is enabled via the PR label."""
        return DEBUG_CI_RESOURCES_LABEL in self._pr_labels

    @cached_property
    def helm_test_packages(self) -> str:
        """JSON-encoded list of all Helm test packages."""
        return json.dumps(all_helm_test_packages())
| |
| @cached_property |
| def affected_providers_list_as_string(self) -> str | None: |
| _ALL_PROVIDERS_LIST = "" |
| if self.full_tests_needed: |
| return _ALL_PROVIDERS_LIST |
| if self._are_all_providers_affected(): |
| return _ALL_PROVIDERS_LIST |
| affected_providers = find_all_providers_affected( |
| changed_files=self._files, |
| include_docs=True, |
| fail_if_suspended_providers_affected=self._fail_if_suspended_providers_affected(), |
| ) |
| if not affected_providers: |
| return None |
| if affected_providers == "ALL_PROVIDERS": |
| return _ALL_PROVIDERS_LIST |
| return " ".join(sorted(affected_providers)) |
| |
    @cached_property
    def runs_on(self) -> str:
        """Which runner pool jobs should use (self-hosted vs public).

        Self-hosted runners are used for canonical-repo push/schedule builds and
        for PRs whose author is a committer (unless the public-runners label is set).
        """
        if self._github_repository == APACHE_AIRFLOW_GITHUB_REPOSITORY:
            if self._github_event in [GithubEvents.SCHEDULE, GithubEvents.PUSH]:
                return RUNS_ON_SELF_HOSTED_RUNNER
            actor = self._github_actor
            if self._github_event in (GithubEvents.PULL_REQUEST, GithubEvents.PULL_REQUEST_TARGET):
                # The triggering actor may differ from the PR author - prefer the author
                # from GITHUB_CONTEXT when available.
                try:
                    actor = self._github_context_dict["event"]["pull_request"]["user"]["login"]
                    get_console().print(
                        f"[warning]The actor: {actor} retrieved from GITHUB_CONTEXT's"
                        f" event.pull_request.user.login[/]"
                    )
                except Exception as e:
                    get_console().print(f"[warning]Exception when reading user login: {e}[/]")
                    get_console().print(
                        f"[info]Could not find the actor from pull request, "
                        f"falling back to the actor who triggered the PR: {actor}[/]"
                    )
            if actor in COMMITTERS and USE_PUBLIC_RUNNERS_LABEL not in self._pr_labels:
                return RUNS_ON_SELF_HOSTED_RUNNER
        return RUNS_ON_PUBLIC_RUNNER
| |
| @cached_property |
| def mssql_parallelism(self) -> int: |
| # Limit parallelism for MSSQL to 1 for public runners due to race conditions generated there |
| return SELF_HOSTED_RUNNERS_CPU_COUNT if self.runs_on == RUNS_ON_SELF_HOSTED_RUNNER else 1 |