Bugfix yaml parsing for GKEStartKueueInsideClusterOperator (#39234)
* Bugfix yaml parsing for GKEStartKueueInsideClusterOperator
* Unit tests
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 006e7b3..f876622 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -19,7 +19,6 @@
from __future__ import annotations
-import re
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
@@ -566,17 +565,10 @@
def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]:
"""Download content of YAML file and separate it into several dictionaries."""
response = requests.get(kueue_yaml_url, allow_redirects=True)
- yaml_dicts = []
- if response.status_code == 200:
- yaml_data = response.text
- documents = re.split(r"---\n", yaml_data)
-
- for document in documents:
- document_dict = yaml.safe_load(document)
- yaml_dicts.append(document_dict)
- else:
+ if response.status_code != 200:
raise AirflowException("Was not able to read the yaml file from given URL")
- return yaml_dicts
+
+ return list(yaml.safe_load_all(response.text))
def execute(self, context: Context):
self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index a6d0de5..05c7bf2 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -129,6 +129,7 @@
requests:
storage: 5Gi
"""
+KUEUE_YAML_URL = "http://test-url/config.yaml"
class TestGoogleCloudPlatformContainerOperator:
@@ -641,6 +642,27 @@
assert hook.gcp_conn_id == "test_conn"
+ @mock.patch(f"{GKE_HOOK_MODULE_PATH}.requests")
+ @mock.patch(f"{GKE_HOOK_MODULE_PATH}.yaml")
+ def test_get_yaml_content_from_file(self, mock_yaml, mock_requests):
+ yaml_content_expected = [mock.MagicMock(), mock.MagicMock()]
+ mock_yaml.safe_load_all.return_value = yaml_content_expected
+ response_text_expected = "response test expected"
+ mock_requests.get.return_value = mock.MagicMock(status_code=200, text=response_text_expected)
+
+ yaml_content_actual = GKEStartKueueInsideClusterOperator._get_yaml_content_from_file(KUEUE_YAML_URL)
+
+ assert yaml_content_actual == yaml_content_expected
+ mock_requests.get.assert_called_once_with(KUEUE_YAML_URL, allow_redirects=True)
+ mock_yaml.safe_load_all.assert_called_once_with(response_text_expected)
+
+ @mock.patch(f"{GKE_HOOK_MODULE_PATH}.requests")
+ def test_get_yaml_content_from_file_exception(self, mock_requests):
+ mock_requests.get.return_value = mock.MagicMock(status_code=400)
+
+ with pytest.raises(AirflowException):
+ GKEStartKueueInsideClusterOperator._get_yaml_content_from_file(KUEUE_YAML_URL)
+
class TestGKEPodOperatorAsync:
def setup_method(self):
diff --git a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
index 0ed6653..d7af95f 100644
--- a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
+++ b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
@@ -103,7 +103,7 @@
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
cluster_name=CLUSTER_NAME,
- kueue_version="v0.5.1",
+ kueue_version="v0.6.2",
)
# [END howto_operator_gke_install_kueue]