[SPARK-50699][PYTHON] Parse and generate DDL string with a specified session
### What changes were proposed in this pull request?
This PR adds private helpers `_to_ddl` and `_parse_ddl` to `SparkContext`, the classic `SparkSession`, and the Spark Connect `SparkSession`, and switches internal call sites to them so that DDL strings are parsed and generated with a specified session instead of the active one.
### Why are the changes needed?
In `_parse_datatype_string` and `toDDL`, a `SparkSession` or `SparkContext` always has to be looked up.
In most cases, the caller already holds a session, so we can avoid creating or fetching the active session.
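
For illustration only (not part of the patch): a minimal sketch of how the new private helpers behave, assuming a local PySpark session. `_parse_ddl` and `_to_ddl` are the internal APIs introduced by this change.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Parse a DDL string with the session we already hold,
# instead of resolving the active session inside _parse_datatype_string.
struct = spark._parse_ddl("a INT, b STRING")

# Generate the DDL string back from the parsed StructType.
print(spark._to_ddl(struct))  # e.g. "a INT,b STRING"

spark.stop()
```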
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49331 from zhengruifeng/py_session_ddl_json.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 9ed4699..42a3685 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -75,6 +75,7 @@
if TYPE_CHECKING:
from pyspark.accumulators import AccumulatorParam
+ from pyspark.sql.types import DataType, StructType
__all__ = ["SparkContext"]
@@ -2623,6 +2624,16 @@
messageParameters={},
)
+ def _to_ddl(self, struct: "StructType") -> str:
+ assert self._jvm is not None
+ return self._jvm.PythonSQLUtils.jsonToDDL(struct.json())
+
+ def _parse_ddl(self, ddl: str) -> "DataType":
+ from pyspark.sql.types import _parse_datatype_json_string
+
+ assert self._jvm is not None
+ return _parse_datatype_json_string(self._jvm.PythonSQLUtils.ddlToJson(ddl))
+
def _test() -> None:
import doctest
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 35b9654..8682057 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -7292,8 +7292,6 @@
4 1 True 1.0
5 2 False 2.0
"""
- from pyspark.sql.types import _parse_datatype_string
-
include_list: List[str]
if not is_list_like(include):
include_list = [cast(str, include)] if include is not None else []
@@ -7320,14 +7318,14 @@
include_spark_type = []
for inc in include_list:
try:
- include_spark_type.append(_parse_datatype_string(inc))
+ include_spark_type.append(self._internal.spark_frame._session._parse_ddl(inc))
except BaseException:
pass
exclude_spark_type = []
for exc in exclude_list:
try:
- exclude_spark_type.append(_parse_datatype_string(exc))
+ exclude_spark_type.append(self._internal.spark_frame._session._parse_ddl(exc))
except BaseException:
pass
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 33956c8..3d8f0ece 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -54,7 +54,7 @@
from pyspark import _NoValue
from pyspark._globals import _NoValueType
from pyspark.util import is_remote_only
-from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_string
+from pyspark.sql.types import Row, StructType, _create_row
from pyspark.sql.dataframe import (
DataFrame as ParentDataFrame,
DataFrameNaFunctions as ParentDataFrameNaFunctions,
@@ -2037,7 +2037,7 @@
_validate_pandas_udf(func, evalType)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 006af87..11adc88 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -35,7 +35,7 @@
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined]
-from pyspark.sql.types import NumericType, StructType, _parse_datatype_string
+from pyspark.sql.types import NumericType, StructType
import pyspark.sql.connect.plan as plan
from pyspark.sql.column import Column
@@ -295,7 +295,7 @@
_validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -369,7 +369,7 @@
_validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -414,7 +414,7 @@
_validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -445,7 +445,7 @@
_validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 925eaae..3f1663d 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1111,6 +1111,16 @@
return creator, (self._session_id,)
+ def _to_ddl(self, struct: StructType) -> str:
+ ddl = self._client._analyze(method="json_to_ddl", json_string=struct.json()).ddl_string
+ assert ddl is not None
+ return ddl
+
+ def _parse_ddl(self, ddl: str) -> DataType:
+ dt = self._client._analyze(method="ddl_parse", ddl_string=ddl).parsed
+ assert dt is not None
+ return dt
+
SparkSession.__doc__ = PySparkSession.__doc__
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index bd12b41..343a68b 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -36,7 +36,7 @@
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
-from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.types import StructType
if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
@@ -348,9 +348,9 @@
]
if isinstance(outputStructType, str):
- outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+ outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
if isinstance(stateStructType, str):
- stateStructType = cast(StructType, _parse_datatype_string(stateStructType))
+ stateStructType = cast(StructType, self._df._session._parse_ddl(stateStructType))
udf = pandas_udf(
func, # type: ignore[call-overload]
@@ -502,7 +502,7 @@
if initialState is not None:
assert isinstance(initialState, GroupedData)
if isinstance(outputStructType, str):
- outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+ outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
def handle_pre_init(
statefulProcessorApiClient: StatefulProcessorApiClient,
@@ -681,7 +681,7 @@
return result
if isinstance(outputStructType, str):
- outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+ outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
df = self._df
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 00fa604..f3a1639 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -58,7 +58,6 @@
_has_nulltype,
_merge_type,
_create_converter,
- _parse_datatype_string,
_from_numpy_type,
)
from pyspark.errors.exceptions.captured import install_exception_handler
@@ -1501,7 +1500,7 @@
)
if isinstance(schema, str):
- schema = cast(Union[AtomicType, StructType, str], _parse_datatype_string(schema))
+ schema = cast(Union[AtomicType, StructType, str], self._parse_ddl(schema))
elif isinstance(schema, (list, tuple)):
# Must re-encode any unicode strings to be consistent with StructField names
schema = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
@@ -2338,6 +2337,12 @@
"""
self._jsparkSession.clearTags()
+ def _to_ddl(self, struct: StructType) -> str:
+ return self._sc._to_ddl(struct)
+
+ def _parse_ddl(self, ddl: str) -> DataType:
+ return self._sc._parse_ddl(ddl)
+
def _test() -> None:
import os
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 93ac665..f40a8bf 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1563,16 +1563,9 @@
session = SparkSession.getActiveSession()
assert session is not None
- return session._client._analyze( # type: ignore[return-value]
- method="json_to_ddl", json_string=self.json()
- ).ddl_string
-
+ return session._to_ddl(self)
else:
- from py4j.java_gateway import JVMView
-
- sc = get_active_spark_context()
- assert sc._jvm is not None
- return cast(JVMView, sc._jvm).PythonSQLUtils.jsonToDDL(self.json())
+ return get_active_spark_context()._to_ddl(self)
class VariantType(AtomicType):
@@ -1907,18 +1900,9 @@
if is_remote():
from pyspark.sql.connect.session import SparkSession
- return cast(
- DataType,
- SparkSession.active()._client._analyze(method="ddl_parse", ddl_string=s).parsed,
- )
-
+ return SparkSession.active()._parse_ddl(s)
else:
- from py4j.java_gateway import JVMView
-
- sc = get_active_spark_context()
- return _parse_datatype_json_string(
- cast(JVMView, sc._jvm).org.apache.spark.sql.api.python.PythonSQLUtils.ddlToJson(s)
- )
+ return get_active_spark_context()._parse_ddl(s)
def _parse_datatype_json_string(json_string: str) -> DataType: