add OpenLineage configuration injection to SparkSubmitOperator
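
Make apache-airflow-providers-common-compat a required dependency of the
apache.spark provider (>=1.5.0) and add two SparkSubmitOperator arguments,
openlineage_inject_parent_job_info and openlineage_inject_transport_info.
Both default to the matching [openlineage] config options
(spark_inject_parent_job_info / spark_inject_transport_info, fallback False).
When enabled, execute() merges the OpenLineage parent job and transport
properties into the operator's conf before the job is submitted, so the
Spark OpenLineage integration can link Spark events back to the Airflow
task. Also fix the timeout-to-milliseconds conversion so that string
timeouts (e.g. "5050" from a composite config) are cast to int before
being multiplied by 1000.

Illustrative usage; the application path and conf values below are
placeholders, not part of this change:

    SparkSubmitOperator(
        task_id="spark_job",
        application="/path/to/app.py",
        conf={"parquet.compression": "SNAPPY"},
        openlineage_inject_parent_job_info=True,
        openlineage_inject_transport_info=True,
    )
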
Signed-off-by: Maciej Obuchowski <maciej.obuchowski@datadoghq.com>
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 65e6604..f4068c8 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -282,6 +282,7 @@
   },
   "apache.spark": {
     "deps": [
+      "apache-airflow-providers-common-compat>=1.5.0",
       "apache-airflow>=2.9.0",
       "grpcio-status>=1.59.0",
       "pyspark>=3.1.3"
diff --git a/providers/apache/spark/README.rst b/providers/apache/spark/README.rst
index cc892e0..d822c5c 100644
--- a/providers/apache/spark/README.rst
+++ b/providers/apache/spark/README.rst
@@ -50,13 +50,14 @@
 Requirements
 ------------

-================== ==================
-PIP package        Version required
-================== ==================
-``apache-airflow`` ``>=2.9.0``
-``pyspark``        ``>=3.1.3``
-``grpcio-status``  ``>=1.59.0``
-================== ==================
+========================================== ==================
+PIP package                                Version required
+========================================== ==================
+``apache-airflow``                         ``>=2.9.0``
+``apache-airflow-providers-common-compat`` ``>=1.5.0``
+``pyspark``                                ``>=3.1.3``
+``grpcio-status``                          ``>=1.59.0``
+========================================== ==================

 Cross provider package dependencies
 -----------------------------------
diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml
index 00158c0..a3f8ddf 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -58,6 +58,7 @@
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.9.0",
+    "apache-airflow-providers-common-compat>=1.5.0",
     "pyspark>=3.1.3",
     "grpcio-status>=1.59.0",
 ]
@@ -68,9 +69,6 @@
 "cncf.kubernetes" = [
     "apache-airflow-providers-cncf-kubernetes>=7.4.0",
 ]
-"common.compat" = [
-    "apache-airflow-providers-common-compat"
-]

 [dependency-groups]
 dev = [
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index a30a51c..0074188 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -125,10 +125,12 @@
                 "name": "pyspark",
             }
         ],
-        "dependencies": ["apache-airflow>=2.9.0", "pyspark>=3.1.3", "grpcio-status>=1.59.0"],
-        "optional-dependencies": {
-            "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"],
-            "common.compat": ["apache-airflow-providers-common-compat"],
-        },
+        "dependencies": [
+            "apache-airflow>=2.9.0",
+            "apache-airflow-providers-common-compat>=1.5.0",
+            "pyspark>=3.1.3",
+            "grpcio-status>=1.59.0",
+        ],
+        "optional-dependencies": {"cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"]},
         "devel-dependencies": [],
     }
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 3ad4ff0..0ba57eb 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -20,8 +20,13 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

+from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.common.compat.openlineage.utils.spark import (
+    inject_parent_job_information_into_spark_properties,
+    inject_transport_information_into_spark_properties,
+)
 from airflow.settings import WEB_COLORS

 if TYPE_CHECKING:
@@ -135,6 +140,12 @@
         yarn_queue: str | None = None,
         deploy_mode: str | None = None,
         use_krb5ccache: bool = False,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
+        openlineage_inject_transport_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_transport_info", fallback=False
+        ),
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -169,9 +180,17 @@
         self._hook: SparkSubmitHook | None = None
         self._conn_id = conn_id
         self._use_krb5ccache = use_krb5ccache
+        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
+        self._openlineage_inject_transport_info = openlineage_inject_transport_info

     def execute(self, context: Context) -> None:
         """Call the SparkSubmitHook to run the provided spark job."""
+        if self._openlineage_inject_parent_job_info:
+            self.log.debug("Injecting OpenLineage parent job information into Spark properties.")
+            self.conf = inject_parent_job_information_into_spark_properties(self.conf, context)
+        if self._openlineage_inject_transport_info:
+            self.log.debug("Injecting OpenLineage transport information into Spark properties.")
+            self.conf = inject_transport_information_into_spark_properties(self.conf, context)
         if self._hook is None:
             self._hook = self._get_hook()
         self._hook.submit(self.application)
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 94344d54..b339093 100644
--- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -17,7 +17,10 @@
 # under the License.
 from __future__ import annotations

+import logging
 from datetime import timedelta
+from unittest import mock
+from unittest.mock import MagicMock

 import pytest

@@ -281,3 +284,179 @@
         assert task.application_args == "application_args"
         assert task.env_vars == "env_vars"
         assert task.properties_file == "properties_file"
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.http import (
+            ApiKeyTokenProvider,
+            HttpCompression,
+            HttpConfig,
+            HttpTransport,
+        )
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
+            config=HttpConfig(
+                url="http://localhost:5000",
+                endpoint="api/v2/lineage",
+                timeout=5050,
+                auth=ApiKeyTokenProvider({"api_key": "12345"}),
+                compression=HttpCompression.GZIP,
+                custom_headers={"X-OpenLineage-Custom-Header": "airflow"},
+            )
+        )
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=False,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute(MagicMock())
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.transport.type": "http",
+            "spark.openlineage.transport.url": "http://localhost:5000",
+            "spark.openlineage.transport.endpoint": "api/v2/lineage",
+            "spark.openlineage.transport.timeoutInMillis": "5050000",
+            "spark.openlineage.transport.compression": "gzip",
+            "spark.openlineage.transport.auth.type": "api_key",
+            "spark.openlineage.transport.auth.apiKey": "Bearer 12345",
+            "spark.openlineage.transport.headers.X-OpenLineage-Custom-Header": "airflow",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict(
+                {
+                    "transports": {
+                        "test1": {
+                            "type": "http",
+                            "url": "http://localhost:5000",
+                            "endpoint": "api/v2/lineage",
+                            "timeout": "5050",
+                            "auth": {
+                                "type": "api_key",
+                                "api_key": "12345",
+                            },
+                            "compression": "gzip",
+                            "custom_headers": {
+                                "X-OpenLineage-Custom-Header": "airflow",
+                            },
+                        },
+                        "test2": {
+                            "type": "http",
+                            "url": "https://example.com:1234",
+                        },
+                        "test3": {"type": "console"},
+                    }
+                }
+            )
+        )
+
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag_id"
+        mock_ti.task_id = "spark_submit_job"
+        mock_ti.try_number = 1
+        mock_ti.dag_run.logical_date = DEFAULT_DATE
+        mock_ti.dag_run.run_after = DEFAULT_DATE
+        mock_ti.logical_date = DEFAULT_DATE
+        mock_ti.map_index = -1
+
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=True,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute({"ti": mock_ti})
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",
+            "spark.openlineage.parentJobNamespace": "default",
+            "spark.openlineage.parentRunId": "01595753-6400-710b-8a12-9e978335a56d",
+            "spark.openlineage.transport.type": "composite",
+            "spark.openlineage.transport.continueOnFailure": "True",
+            "spark.openlineage.transport.transports.test1.type": "http",
+            "spark.openlineage.transport.transports.test1.url": "http://localhost:5000",
+            "spark.openlineage.transport.transports.test1.endpoint": "api/v2/lineage",
+            "spark.openlineage.transport.transports.test1.timeoutInMillis": "5050000",
+            "spark.openlineage.transport.transports.test1.auth.type": "api_key",
+            "spark.openlineage.transport.transports.test1.auth.apiKey": "Bearer 12345",
+            "spark.openlineage.transport.transports.test1.compression": "gzip",
+            "spark.openlineage.transport.transports.test1.headers.X-OpenLineage-Custom-Header": "airflow",
+            "spark.openlineage.transport.transports.test2.type": "http",
+            "spark.openlineage.transport.transports.test2.url": "https://example.com:1234",
+            "spark.openlineage.transport.transports.test2.endpoint": "api/v1/lineage",
+            "spark.openlineage.transport.transports.test2.timeoutInMillis": "5000",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_openlineage_composite_config_wrong_transport_to_spark(
+        self, mock_get_openlineage_listener, mock_get_hook, caplog
+    ):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict({"transports": {"test1": {"type": "console"}}})
+        )
+
+        with caplog.at_level(logging.INFO):
+            operator = SparkSubmitOperator(
+                task_id="spark_submit_job",
+                spark_binary="sparky",
+                dag=self.dag,
+                openlineage_inject_parent_job_info=False,
+                openlineage_inject_transport_info=True,
+                **self._config,
+            )
+            operator.execute(MagicMock())
+
+        assert (
+            "OpenLineage transport type `composite` does not contain http transport. Skipping injection of OpenLineage transport information into Spark properties."
+            in caplog.text
+        )
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_openlineage_simple_config_wrong_transport_to_spark(
+        self, mock_get_openlineage_listener, mock_get_hook, caplog
+    ):
+        # Given / When
+        from openlineage.client.transport.console import ConsoleConfig, ConsoleTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = ConsoleTransport(
+            config=ConsoleConfig()
+        )
+
+        with caplog.at_level(logging.INFO):
+            operator = SparkSubmitOperator(
+                task_id="spark_submit_job",
+                spark_binary="sparky",
+                dag=self.dag,
+                openlineage_inject_parent_job_info=False,
+                openlineage_inject_transport_info=True,
+                **self._config,
+            )
+            operator.execute(MagicMock())
+ assert "OpenLineage transport type `console` does not support automatic injection of OpenLineage transport information into Spark properties."
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
diff --git a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
index c6a4313..aed0285 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
@@ -99,7 +99,7 @@
             "url": tp.url,
             "endpoint": tp.endpoint,
             "timeoutInMillis": str(
-                int(tp.timeout * 1000)
+                int(tp.timeout) * 1000
                 # convert to milliseconds, as required by Spark integration
             ),
         }
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index 9f0fef8..becb4bd 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -60,7 +60,7 @@
             "url": tp.url,
             "endpoint": tp.endpoint,
             "timeoutInMillis": str(
-                int(tp.timeout * 1000)  # convert to milliseconds, as required by Spark integration
+                int(tp.timeout) * 1000  # convert to milliseconds, as required by Spark integration
             ),
         }
         if hasattr(tp, "compression") and tp.compression: