#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tempfile
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from typing import List, Optional, Sequence
from unittest import mock

import pytest
from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS

from airflow.providers.google.cloud.utils.credentials_provider import provide_gcp_conn_and_credentials
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY
from tests.test_utils import AIRFLOW_MAIN_FOLDER
from tests.test_utils.system_tests_class import SystemTest
from tests.utils.logging_command_executor import get_executor
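
# Locations of the example DAGs and test configuration used by the Google system tests.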
CLOUD_DAG_FOLDER = os.path.join(
AIRFLOW_MAIN_FOLDER, "airflow", "providers", "google", "cloud", "example_dags"
)
MARKETING_DAG_FOLDER = os.path.join(
AIRFLOW_MAIN_FOLDER, "airflow", "providers", "google", "marketing_platform", "example_dags"
)
GSUITE_DAG_FOLDER = os.path.join(
AIRFLOW_MAIN_FOLDER, "airflow", "providers", "google", "suite", "example_dags"
)
FIREBASE_DAG_FOLDER = os.path.join(
AIRFLOW_MAIN_FOLDER, "airflow", "providers", "google", "firebase", "example_dags"
)
POSTGRES_LOCAL_EXECUTOR = os.path.join(
AIRFLOW_MAIN_FOLDER, "tests", "test_utils", "postgres_local_executor.cfg"
)


def resolve_full_gcp_key_path(key: str) -> str:
    """
    Returns the full path to the provided GCP key.

    :param key: Name of the GCP key, for example ``my_service.json``
    :type key: str
    :returns: Full path to the key
    """
    path = os.environ.get("CREDENTIALS_DIR", "/files/airflow-breeze-config/keys")
    key = os.path.join(path, key)
    return key
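
# Illustrative example (the key name is hypothetical):
#
#     resolve_full_gcp_key_path("my_service.json")
#     # -> "/files/airflow-breeze-config/keys/my_service.json" if CREDENTIALS_DIR is unset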


@contextmanager
def provide_gcp_context(
    key_file_path: Optional[str] = None,
    scopes: Optional[Sequence] = None,
    project_id: Optional[str] = None,
):
    """
    Context manager that provides:

    - GCP credentials for applications supporting the `Application Default Credentials (ADC)
      strategy <https://cloud.google.com/docs/authentication/production>`__,
    - a temporary value of the ``AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT`` connection,
    - a ``gcloud`` config directory isolated from the user configuration.

    Moreover, it resolves the full path to service keys, so the user can pass
    ``myservice.json`` as ``key_file_path``.

    :param key_file_path: Path to the ``.json`` file that contains GCP credentials.
    :type key_file_path: str
    :param scopes: OAuth scopes for the connection
    :type scopes: Sequence
    :param project_id: The ID of the GCP project for the connection.
    :type project_id: str
    """
key_file_path = resolve_full_gcp_key_path(key_file_path) # type: ignore
with provide_gcp_conn_and_credentials(key_file_path, scopes, project_id), \
tempfile.TemporaryDirectory() as gcloud_config_tmp, \
mock.patch.dict('os.environ', {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}):
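        # Inside the isolated gcloud configuration, point gcloud at the given
        # project and activate the service account for CLI calls.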
executor = get_executor()
if project_id:
executor.execute_cmd([
"gcloud", "config", "set", "core/project", project_id
])
if key_file_path:
executor.execute_cmd([
"gcloud", "auth", "activate-service-account", f"--key-file={key_file_path}",
])
yield
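
# A minimal usage sketch (key file and project are hypothetical):
#
#     with provide_gcp_context("my_service.json", project_id="example-project"):
#         ...  # runs with ADC credentials, an Airflow GCP connection, and an
#              # isolated gcloud configuration set to example-project
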
@pytest.mark.system("google")
class GoogleSystemTest(SystemTest):
    @staticmethod
    def _project_id():
        return os.environ.get("GCP_PROJECT_ID")

    @staticmethod
    def _service_key():
        return os.environ.get(CREDENTIALS)

    @classmethod
    def execute_with_ctx(cls, cmd: List[str], key: str = GCP_GCS_KEY, project_id=None, scopes=None):
        """
        Executes a command within the context created by ``provide_gcp_context``,
        with the service key activated.
        """
        executor = get_executor()
        current_project_id = project_id or cls._project_id()
        with provide_gcp_context(key, project_id=current_project_id, scopes=scopes):
            executor.execute_cmd(cmd=cmd)
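
    # Illustrative call (the command is arbitrary); the project falls back to
    # the GCP_PROJECT_ID environment variable when not passed explicitly:
    #
    #     GoogleSystemTest.execute_with_ctx(["gcloud", "info"], key=GCP_GCS_KEY)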

    @classmethod
    def create_gcs_bucket(cls, name: str, location: Optional[str] = None) -> None:
        bucket_name = f"gs://{name}" if not name.startswith("gs://") else name
        cmd = ["gsutil", "mb"]
        if location:
            cmd += ["-c", "regional", "-l", location]
        cmd += [bucket_name]
        cls.execute_with_ctx(cmd, key=GCP_GCS_KEY)

    @classmethod
    def delete_gcs_bucket(cls, name: str):
        bucket_name = f"gs://{name}" if not name.startswith("gs://") else name
        cmd = ["gsutil", "-m", "rm", "-r", bucket_name]
        cls.execute_with_ctx(cmd, key=GCP_GCS_KEY)

    @classmethod
    def upload_to_gcs(cls, source_uri: str, target_uri: str):
        cls.execute_with_ctx(
            ["gsutil", "cp", source_uri, target_uri], key=GCP_GCS_KEY
        )

    @classmethod
    def upload_content_to_gcs(cls, lines: str, bucket: str, filename: str):
        bucket_name = f"gs://{bucket}" if not bucket.startswith("gs://") else bucket
        with TemporaryDirectory(prefix="airflow-gcp") as tmp_dir:
            tmp_path = os.path.join(tmp_dir, filename)
            with open(tmp_path, "w") as file:
                file.writelines(lines)
                file.flush()
            # Mode is octal: read (and execute) for everyone, no write access.
            os.chmod(tmp_path, 0o555)
            cls.upload_to_gcs(tmp_path, bucket_name)

    @classmethod
    def get_project_number(cls, project_id: str) -> str:
        cmd = ['gcloud', 'projects', 'describe', project_id, '--format', 'value(projectNumber)']
        return cls.check_output(cmd).decode("utf-8").strip()

    @classmethod
    def grant_bucket_access(cls, bucket: str, account_email: str):
        bucket_name = f"gs://{bucket}" if not bucket.startswith("gs://") else bucket
        cls.execute_cmd(
            [
                "gsutil",
                "iam",
                "ch",
                f"serviceAccount:{account_email}:admin",
                bucket_name,
            ]
        )
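
# A minimal usage sketch for the helpers above (bucket and file names are
# hypothetical; requires GCP_PROJECT_ID and valid service-account keys):
#
#     class ExampleSystemTest(GoogleSystemTest):
#         def test_bucket_round_trip(self):
#             self.create_gcs_bucket("example-bucket", location="europe-west3")
#             self.upload_content_to_gcs("hello\n", bucket="example-bucket", filename="hello.txt")
#             self.delete_gcs_bucket("example-bucket")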