| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import os |
| import tempfile |
| from contextlib import contextmanager |
| from tempfile import TemporaryDirectory |
| from typing import Sequence |
| from unittest import mock |
| |
| import pytest |
| from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS |
| |
| from airflow.providers.google.cloud.utils.credentials_provider import provide_gcp_conn_and_credentials |
| from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY, GCP_SECRET_MANAGER_KEY |
| from tests.test_utils import AIRFLOW_MAIN_FOLDER |
| from tests.test_utils.logging_command_executor import CommandExecutor |
| from tests.test_utils.system_tests_class import SystemTest |
| |
# Root of the Google provider package inside the Airflow source tree; every
# example-DAG folder below is a sub-directory of it.
_GOOGLE_PROVIDER_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "providers", "google")

# Example-DAG folders, one per Google provider sub-package.
CLOUD_DAG_FOLDER = os.path.join(_GOOGLE_PROVIDER_FOLDER, "cloud", "example_dags")
MARKETING_DAG_FOLDER = os.path.join(_GOOGLE_PROVIDER_FOLDER, "marketing_platform", "example_dags")
GSUITE_DAG_FOLDER = os.path.join(_GOOGLE_PROVIDER_FOLDER, "suite", "example_dags")
FIREBASE_DAG_FOLDER = os.path.join(_GOOGLE_PROVIDER_FOLDER, "firebase", "example_dags")
LEVELDB_DAG_FOLDER = os.path.join(_GOOGLE_PROVIDER_FOLDER, "leveldb", "example_dags")

# Path of the ``postgres_local_executor.cfg`` file under ``tests/test_utils``.
POSTGRES_LOCAL_EXECUTOR = os.path.join(
    AIRFLOW_MAIN_FOLDER, "tests", "test_utils", "postgres_local_executor.cfg"
)
| |
| |
def resolve_full_gcp_key_path(key: str) -> str:
    """
    Return the full path to the provided GCP key.

    The key name is joined onto the directory named by the ``CREDENTIALS_DIR``
    environment variable; when that variable is not set,
    ``/files/airflow-breeze-config/keys`` is used as the directory.

    :param key: Name of the GCP key, for example ``my_service.json``
    :returns: Full path to the key
    """
    keys_dir = os.getenv("CREDENTIALS_DIR", "/files/airflow-breeze-config/keys")
    return os.path.join(keys_dir, key)
| |
| |
@contextmanager
def provide_gcp_context(
    key_file_path: str | None = None,
    scopes: Sequence | None = None,
    project_id: str | None = None,
):
    """
    Context manager that provides:

    - GCP credentials for application supporting `Application Default Credentials (ADC)
      strategy <https://cloud.google.com/docs/authentication/production>`__.
    - temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT` variable
    - the ``gcloud`` config directory isolated from user configuration

    Moreover it resolves full path to service keys so user can pass ``myservice.json``
    as ``key_file_path``.

    :param key_file_path: Path to file with GCP credentials .json file. When ``None``,
        no key path is resolved and no service account is activated in ``gcloud``.
    :param scopes: OAuth scopes for the connection
    :param project_id: The id of GCP project for the connection.
        Default: ``os.environ["GCP_PROJECT_ID"]`` or None
    """
    # Resolve only a key that was actually supplied: ``os.path.join`` raises
    # ``TypeError`` for ``None``, which would make the default argument unusable.
    if key_file_path is not None:
        key_file_path = resolve_full_gcp_key_path(key_file_path)
    if project_id is None:
        project_id = os.environ.get("GCP_PROJECT_ID")
    with provide_gcp_conn_and_credentials(
        key_file_path, scopes, project_id
    ), tempfile.TemporaryDirectory() as gcloud_config_tmp, mock.patch.dict(
        "os.environ", {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}
    ):
        # ``gcloud`` is pointed at a throw-away config directory for the duration
        # of the context, so the commands below do not touch the user's own config.
        executor = CommandExecutor()

        if key_file_path:
            executor.execute_cmd(
                [
                    "gcloud",
                    "auth",
                    "activate-service-account",
                    f"--key-file={key_file_path}",
                ]
            )
        if project_id:
            executor.execute_cmd(["gcloud", "config", "set", "core/project", project_id])
        yield
| |
| |
@contextmanager
@provide_gcp_context(GCP_GCS_KEY)
def provide_gcs_bucket(bucket_name: str):
    """
    Context manager that creates a GCS bucket on entry and deletes it on exit.

    The bucket is deleted even when the body of the ``with`` block raises, so a
    failing caller does not leave the bucket behind.

    :param bucket_name: Name of the bucket to create and later delete.
    """
    # NOTE(review): ``provide_gcp_context`` is applied to a generator function, so
    # it presumably wraps only the creation of the generator object rather than
    # the statements below - confirm whether the decorator is still needed.
    GoogleSystemTest.create_gcs_bucket(bucket_name)
    try:
        yield
    finally:
        GoogleSystemTest.delete_gcs_bucket(bucket_name)
| |
| |
@pytest.mark.system("google")
class GoogleSystemTest(SystemTest):
    """
    Base class for Google system tests.

    Bundles helpers that shell out to ``gcloud`` / ``gsutil`` for managing GCS
    buckets and Secret Manager secrets used by the tests.
    """

    @staticmethod
    def _bucket_uri(name: str) -> str:
        """Return ``name`` as a ``gs://`` URI, adding the scheme only when it is missing."""
        return name if name.startswith("gs://") else f"gs://{name}"

    @staticmethod
    def execute_cmd(*args, **kwargs):
        """Run a command through a fresh ``CommandExecutor`` and return its result."""
        executor = CommandExecutor()
        return executor.execute_cmd(*args, **kwargs)

    @staticmethod
    def _project_id():
        """Return the value of the ``GCP_PROJECT_ID`` environment variable, or ``None``."""
        return os.environ.get("GCP_PROJECT_ID")

    @staticmethod
    def _service_key():
        """Return the value of the ``google.auth`` credentials environment variable, or ``None``."""
        return os.environ.get(CREDENTIALS)

    @classmethod
    def execute_with_ctx(
        cls, cmd: list[str], key: str = GCP_GCS_KEY, project_id=None, scopes=None, silent: bool = False
    ):
        """
        Executes command with context created by provide_gcp_context and activated
        service key.

        :param cmd: Command to run, as a list of arguments.
        :param key: Name of the GCP key file passed to ``provide_gcp_context``.
        :param project_id: GCP project id; falls back to ``GCP_PROJECT_ID`` when falsy.
        :param scopes: OAuth scopes passed to ``provide_gcp_context``.
        :param silent: Forwarded to ``execute_cmd``.
        """
        current_project_id = project_id or cls._project_id()
        with provide_gcp_context(key, project_id=current_project_id, scopes=scopes):
            cls.execute_cmd(cmd=cmd, silent=silent)

    @classmethod
    def create_gcs_bucket(cls, name: str, location: str | None = None) -> None:
        """
        Create a GCS bucket with ``gsutil mb``.

        :param name: Bucket name, with or without the ``gs://`` prefix.
        :param location: When given, the bucket is created as ``regional`` in this location.
        """
        cmd = ["gsutil", "mb"]
        if location:
            cmd += ["-c", "regional", "-l", location]
        cmd += [cls._bucket_uri(name)]
        cls.execute_with_ctx(cmd, key=GCP_GCS_KEY)

    @classmethod
    def delete_gcs_bucket(cls, name: str):
        """
        Recursively delete a GCS bucket and its content with ``gsutil -m rm -r``.

        :param name: Bucket name, with or without the ``gs://`` prefix.
        """
        cmd = ["gsutil", "-m", "rm", "-r", cls._bucket_uri(name)]
        cls.execute_with_ctx(cmd, key=GCP_GCS_KEY)

    @classmethod
    def upload_to_gcs(cls, source_uri: str, target_uri: str):
        """
        Copy ``source_uri`` to ``target_uri`` with ``gsutil cp``.

        :param source_uri: Source path or URI.
        :param target_uri: Destination path or URI.
        """
        cls.execute_with_ctx(["gsutil", "cp", source_uri, target_uri], key=GCP_GCS_KEY)

    @classmethod
    def upload_content_to_gcs(cls, lines: str, bucket: str, filename: str):
        """
        Write ``lines`` to a temporary file and upload that file to ``bucket``.

        :param lines: Content of the file. It is passed to ``file.writelines``, so
            a single string and an iterable of strings are both accepted.
        :param bucket: Target bucket, with or without the ``gs://`` prefix.
        :param filename: Name of the temporary file; may contain sub-directories,
            which are created inside the temporary directory.
        """
        bucket_name = cls._bucket_uri(bucket)
        with TemporaryDirectory(prefix="airflow-gcp") as tmp_dir:
            tmp_path = os.path.join(tmp_dir, filename)
            tmp_dir_path = os.path.dirname(tmp_path)
            if tmp_dir_path:
                os.makedirs(tmp_dir_path, exist_ok=True)
            with open(tmp_path, "w") as file:
                file.writelines(lines)
                # Flush so the content is on disk before the upload command reads it.
                file.flush()
                # The mode must be an octal literal: decimal ``777`` is ``0o1411``
                # (sticky bit, owner read-only), not the intended ``rwxrwxrwx``.
                os.chmod(tmp_path, 0o777)
                cls.upload_to_gcs(tmp_path, bucket_name)

    @classmethod
    def get_project_number(cls, project_id: str) -> str:
        """
        Return the project number of ``project_id`` as reported by ``gcloud projects describe``.

        Relies on a ``check_output`` classmethod that is not defined in this class
        (presumably inherited from ``SystemTest``).

        :param project_id: Id of the GCP project.
        :returns: The command output decoded as UTF-8 with surrounding whitespace stripped.
        """
        cmd = ["gcloud", "projects", "describe", project_id, "--format", "value(projectNumber)"]
        return cls.check_output(cmd).decode("utf-8").strip()

    @classmethod
    def grant_bucket_access(cls, bucket: str, account_email: str):
        """
        Grant the ``admin`` role on ``bucket`` to a service account with ``gsutil iam ch``.

        Unlike the other helpers, the command is run directly through ``execute_cmd``,
        without ``provide_gcp_context``.

        :param bucket: Bucket name, with or without the ``gs://`` prefix.
        :param account_email: E-mail of the service account that receives access.
        """
        cls.execute_cmd(
            [
                "gsutil",
                "iam",
                "ch",
                f"serviceAccount:{account_email}:admin",
                cls._bucket_uri(bucket),
            ]
        )

    @classmethod
    def delete_secret(cls, name: str, silent: bool = False):
        """
        Delete a Secret Manager secret with ``gcloud secrets delete``.

        :param name: Name of the secret.
        :param silent: Forwarded to ``execute_with_ctx``.
        """
        cmd = ["gcloud", "secrets", "delete", name, "--project", cls._project_id(), "--quiet"]
        cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY, silent=silent)

    @classmethod
    def create_secret(cls, name: str, value: str):
        """
        Create a Secret Manager secret with ``gcloud secrets create``.

        The value is written UTF-8 encoded to a temporary file that is handed to
        ``gcloud`` through ``--data-file``.

        :param name: Name of the secret.
        :param value: Value stored in the secret.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write(value.encode("UTF-8"))
            # Flush so ``gcloud`` sees the full value when it reads the file by name.
            tmp.flush()
            cmd = [
                "gcloud",
                "secrets",
                "create",
                name,
                "--replication-policy",
                "automatic",
                "--project",
                cls._project_id(),
                "--data-file",
                tmp.name,
            ]
            cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY)

    @classmethod
    def update_secret(cls, name: str, value: str):
        """
        Add a new version to an existing secret with ``gcloud secrets versions add``.

        The value is written UTF-8 encoded to a temporary file that is handed to
        ``gcloud`` through ``--data-file``.

        :param name: Name of the secret.
        :param value: Value stored in the new secret version.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write(value.encode("UTF-8"))
            # Flush so ``gcloud`` sees the full value when it reads the file by name.
            tmp.flush()
            cmd = [
                "gcloud",
                "secrets",
                "versions",
                "add",
                name,
                "--project",
                cls._project_id(),
                "--data-file",
                tmp.name,
            ]
            cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY)