Add OpenLineage configuration injection to SparkSubmitOperator

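This change lets SparkSubmitOperator automatically propagate OpenLineage
parent job information and transport configuration to the submitted Spark
job as spark.openlineage.* properties. Injection is controlled by two new
operator arguments, openlineage_inject_parent_job_info and
openlineage_inject_transport_info, which default to the [openlineage]
options spark_inject_parent_job_info and spark_inject_transport_info.

A minimal usage sketch (the application path, connection id and conf values
below are placeholders, not part of this change):

    from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

    submit = SparkSubmitOperator(
        task_id="spark_submit_job",
        application="/path/to/app.py",  # placeholder application
        conn_id="spark_default",
        conf={"parquet.compression": "SNAPPY"},  # user-provided Spark properties
        # adds spark.openlineage.parentJobName / parentJobNamespace / parentRunId
        openlineage_inject_parent_job_info=True,
        # adds spark.openlineage.transport.* derived from the active OpenLineage transport
        openlineage_inject_transport_info=True,
    )
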
Signed-off-by: Maciej Obuchowski <maciej.obuchowski@datadoghq.com>
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 65e6604..f4068c8 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -282,6 +282,7 @@
   },
   "apache.spark": {
     "deps": [
+      "apache-airflow-providers-common-compat>=1.5.0",
       "apache-airflow>=2.9.0",
       "grpcio-status>=1.59.0",
       "pyspark>=3.1.3"
diff --git a/providers/apache/spark/README.rst b/providers/apache/spark/README.rst
index cc892e0..d822c5c 100644
--- a/providers/apache/spark/README.rst
+++ b/providers/apache/spark/README.rst
@@ -50,13 +50,14 @@
 Requirements
 ------------
 
-==================  ==================
-PIP package         Version required
-==================  ==================
-``apache-airflow``  ``>=2.9.0``
-``pyspark``         ``>=3.1.3``
-``grpcio-status``   ``>=1.59.0``
-==================  ==================
+==========================================  ==================
+PIP package                                 Version required
+==========================================  ==================
+``apache-airflow``                          ``>=2.9.0``
+``apache-airflow-providers-common-compat``  ``>=1.5.0``
+``pyspark``                                 ``>=3.1.3``
+``grpcio-status``                           ``>=1.59.0``
+==========================================  ==================
 
 Cross provider package dependencies
 -----------------------------------
diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml
index 00158c0..a3f8ddf 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -58,6 +58,7 @@
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.9.0",
+    "apache-airflow-providers-common-compat>=1.5.0",
     "pyspark>=3.1.3",
     "grpcio-status>=1.59.0",
 ]
@@ -68,9 +69,6 @@
 "cncf.kubernetes" = [
     "apache-airflow-providers-cncf-kubernetes>=7.4.0",
 ]
-"common.compat" = [
-    "apache-airflow-providers-common-compat"
-]
 
 [dependency-groups]
 dev = [
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index a30a51c..0074188 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -125,10 +125,12 @@
                 "name": "pyspark",
             }
         ],
-        "dependencies": ["apache-airflow>=2.9.0", "pyspark>=3.1.3", "grpcio-status>=1.59.0"],
-        "optional-dependencies": {
-            "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"],
-            "common.compat": ["apache-airflow-providers-common-compat"],
-        },
+        "dependencies": [
+            "apache-airflow>=2.9.0",
+            "apache-airflow-providers-common-compat>=1.5.0",
+            "pyspark>=3.1.3",
+            "grpcio-status>=1.59.0",
+        ],
+        "optional-dependencies": {"cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"]},
         "devel-dependencies": [],
     }
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 3ad4ff0..0ba57eb 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -20,8 +20,13 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.common.compat.openlineage.utils.spark import (
+    inject_parent_job_information_into_spark_properties,
+    inject_transport_information_into_spark_properties,
+)
 from airflow.settings import WEB_COLORS
 
 if TYPE_CHECKING:
@@ -135,6 +140,12 @@
         yarn_queue: str | None = None,
         deploy_mode: str | None = None,
         use_krb5ccache: bool = False,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
+        openlineage_inject_transport_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_transport_info", fallback=False
+        ),
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -169,9 +180,17 @@
         self._hook: SparkSubmitHook | None = None
         self._conn_id = conn_id
         self._use_krb5ccache = use_krb5ccache
+        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
+        self._openlineage_inject_transport_info = openlineage_inject_transport_info
 
     def execute(self, context: Context) -> None:
         """Call the SparkSubmitHook to run the provided spark job."""
+        if self._openlineage_inject_parent_job_info:
+            self.log.debug("Injecting OpenLineage parent job information into Spark properties.")
+            self.conf = inject_parent_job_information_into_spark_properties(self.conf, context)
+        if self._openlineage_inject_transport_info:
+            self.log.debug("Injecting OpenLineage transport information into Spark properties.")
+            self.conf = inject_transport_information_into_spark_properties(self.conf, context)
         if self._hook is None:
             self._hook = self._get_hook()
         self._hook.submit(self.application)
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 94344d54..b339093 100644
--- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -17,7 +17,10 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from datetime import timedelta
+from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 
@@ -281,3 +284,182 @@
         assert task.application_args == "application_args"
         assert task.env_vars == "env_vars"
         assert task.properties_file == "properties_file"
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.http import (
+            ApiKeyTokenProvider,
+            HttpCompression,
+            HttpConfig,
+            HttpTransport,
+        )
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
+            config=HttpConfig(
+                url="http://localhost:5000",
+                endpoint="api/v2/lineage",
+                timeout=5050,
+                auth=ApiKeyTokenProvider({"api_key": "12345"}),
+                compression=HttpCompression.GZIP,
+                custom_headers={"X-OpenLineage-Custom-Header": "airflow"},
+            )
+        )
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=False,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute(MagicMock())
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.transport.type": "http",
+            "spark.openlineage.transport.url": "http://localhost:5000",
+            "spark.openlineage.transport.endpoint": "api/v2/lineage",
+            "spark.openlineage.transport.timeoutInMillis": "5050000",
+            "spark.openlineage.transport.compression": "gzip",
+            "spark.openlineage.transport.auth.type": "api_key",
+            "spark.openlineage.transport.auth.apiKey": "Bearer 12345",
+            "spark.openlineage.transport.headers.X-OpenLineage-Custom-Header": "airflow",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict(
+                {
+                    "transports": {
+                        "test1": {
+                            "type": "http",
+                            "url": "http://localhost:5000",
+                            "endpoint": "api/v2/lineage",
+                            "timeout": "5050",
+                            "auth": {
+                                "type": "api_key",
+                                "api_key": "12345",
+                            },
+                            "compression": "gzip",
+                            "custom_headers": {
+                                "X-OpenLineage-Custom-Header": "airflow",
+                            },
+                        },
+                        "test2": {
+                            "type": "http",
+                            "url": "https://example.com:1234",
+                        },
+                        "test3": {"type": "console"},
+                    }
+                }
+            )
+        )
+
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag_id"
+        mock_ti.task_id = "spark_submit_job"
+        mock_ti.try_number = 1
+        mock_ti.dag_run.logical_date = DEFAULT_DATE
+        mock_ti.dag_run.run_after = DEFAULT_DATE
+        mock_ti.logical_date = DEFAULT_DATE
+        mock_ti.map_index = -1
+
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=True,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute({"ti": mock_ti})
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",
+            "spark.openlineage.parentJobNamespace": "default",
+            "spark.openlineage.parentRunId": "01595753-6400-710b-8a12-9e978335a56d",
+            "spark.openlineage.transport.type": "composite",
+            "spark.openlineage.transport.continueOnFailure": "True",
+            "spark.openlineage.transport.transports.test1.type": "http",
+            "spark.openlineage.transport.transports.test1.url": "http://localhost:5000",
+            "spark.openlineage.transport.transports.test1.endpoint": "api/v2/lineage",
+            "spark.openlineage.transport.transports.test1.timeoutInMillis": "5050000",
+            "spark.openlineage.transport.transports.test1.auth.type": "api_key",
+            "spark.openlineage.transport.transports.test1.auth.apiKey": "Bearer 12345",
+            "spark.openlineage.transport.transports.test1.compression": "gzip",
+            "spark.openlineage.transport.transports.test1.headers.X-OpenLineage-Custom-Header": "airflow",
+            "spark.openlineage.transport.transports.test2.type": "http",
+            "spark.openlineage.transport.transports.test2.url": "https://example.com:1234",
+            "spark.openlineage.transport.transports.test2.endpoint": "api/v1/lineage",
+            "spark.openlineage.transport.transports.test2.timeoutInMillis": "5000",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_openlineage_composite_config_wrong_transport_to_spark(
+        self, mock_get_openlineage_listener, mock_get_hook, caplog
+    ):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict({"transports": {"test1": {"type": "console"}}})
+        )
+
+        with caplog.at_level(logging.INFO):
+            operator = SparkSubmitOperator(
+                task_id="spark_submit_job",
+                spark_binary="sparky",
+                dag=self.dag,
+                openlineage_inject_parent_job_info=False,
+                openlineage_inject_transport_info=True,
+                **self._config,
+            )
+            operator.execute(MagicMock())
+
+            assert (
+                "OpenLineage transport type `composite` does not contain http transport. Skipping injection of OpenLineage transport information into Spark properties."
+                in caplog.text
+            )
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_openlineage_simple_config_wrong_transport_to_spark(
+        self, mock_get_openlineage_listener, mock_get_hook, caplog
+    ):
+        # Given / When
+        from openlineage.client.transport.console import ConsoleConfig, ConsoleTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = ConsoleTransport(
+            config=ConsoleConfig()
+        )
+
+        with caplog.at_level(logging.INFO):
+            operator = SparkSubmitOperator(
+                task_id="spark_submit_job",
+                spark_binary="sparky",
+                dag=self.dag,
+                openlineage_inject_parent_job_info=False,
+                openlineage_inject_transport_info=True,
+                **self._config,
+            )
+            operator.execute(MagicMock())
+
+            assert (
+                "OpenLineage transport type `console` does not support automatic injection of OpenLineage transport information into Spark properties."
+                in caplog.text
+            )
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
diff --git a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
index c6a4313..aed0285 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
@@ -99,7 +99,7 @@
                             "url": tp.url,
                             "endpoint": tp.endpoint,
                             "timeoutInMillis": str(
-                                int(tp.timeout * 1000)
+                                int(float(tp.timeout) * 1000)
                                 # convert to milliseconds, as required by Spark integration
                             ),
                         }
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index 9f0fef8..becb4bd 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -60,7 +60,7 @@
             "url": tp.url,
             "endpoint": tp.endpoint,
             "timeoutInMillis": str(
-                int(tp.timeout * 1000)  # convert to milliseconds, as required by Spark integration
+                int(float(tp.timeout) * 1000)  # convert to milliseconds, as required by Spark integration
             ),
         }
         if hasattr(tp, "compression") and tp.compression: