| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import uuid |
| from pyspark.sql.connect.utils import check_dependencies |
| |
| check_dependencies(__name__) |
| |
| import json |
| import threading |
| import os |
| import warnings |
| from collections.abc import Callable, Sized |
| import functools |
| from threading import RLock |
| from typing import ( |
| Optional, |
| Any, |
| Union, |
| Dict, |
| List, |
| Tuple, |
| Set, |
| cast, |
| overload, |
| Iterable, |
| Mapping, |
| TYPE_CHECKING, |
| ClassVar, |
| ) |
| |
| import numpy as np |
| import pandas as pd |
| import pyarrow as pa |
| from pandas.api.types import ( # type: ignore[attr-defined] |
| is_datetime64_dtype, |
| is_timedelta64_dtype, |
| ) |
import urllib.parse
| |
| from pyspark.sql.connect.dataframe import DataFrame |
| from pyspark.sql.dataframe import DataFrame as ParentDataFrame |
| from pyspark.sql.connect.logging import logger |
| from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder |
| from pyspark.sql.connect.conf import RuntimeConf |
| from pyspark.sql.connect.plan import ( |
| SQL, |
| Range, |
| LocalRelation, |
| LogicalPlan, |
| CachedLocalRelation, |
| CachedRelation, |
| CachedRemoteRelation, |
| SubqueryAlias, |
| ) |
| from pyspark.sql.connect.functions import builtin as F |
| from pyspark.sql.connect.profiler import ProfilerCollector |
| from pyspark.sql.connect.readwriter import DataFrameReader |
| from pyspark.sql.connect.streaming.readwriter import DataStreamReader |
| from pyspark.sql.connect.streaming.query import StreamingQueryManager |
| from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer |
| from pyspark.sql.pandas.types import ( |
| to_arrow_schema, |
| to_arrow_type, |
| _deduplicate_field_names, |
| from_arrow_schema, |
| from_arrow_type, |
| _check_arrow_table_timestamps_localize, |
| ) |
| from pyspark.sql.profiler import Profile |
| from pyspark.sql.session import classproperty, SparkSession as PySparkSession |
| from pyspark.sql.types import ( |
| _infer_schema, |
| _has_nulltype, |
| _merge_type, |
| Row, |
| DataType, |
| DayTimeIntervalType, |
| StructType, |
| AtomicType, |
| TimestampType, |
| MapType, |
| StringType, |
| ) |
| from pyspark.sql.utils import to_str |
| from pyspark.errors import ( |
| PySparkAttributeError, |
| PySparkNotImplementedError, |
| PySparkRuntimeError, |
| PySparkValueError, |
| PySparkTypeError, |
| PySparkAssertionError, |
| ) |
| |
| if TYPE_CHECKING: |
| import pyspark.sql.connect.proto as pb2 |
| from pyspark.sql.connect._typing import OptionalPrimitiveType |
| from pyspark.sql.connect.catalog import Catalog |
| from pyspark.sql.connect.udf import UDFRegistration |
| from pyspark.sql.connect.udtf import UDTFRegistration |
| from pyspark.sql.connect.tvf import TableValuedFunction |
| from pyspark.sql.connect.shell.progress import ProgressHandler |
| from pyspark.sql.connect.datasource import DataSourceRegistration |
| |
| |
| class SparkSession: |
| # The active SparkSession for the current thread |
| _active_session: ClassVar[threading.local] = threading.local() |
| # Reference to the root SparkSession |
| _default_session: ClassVar[Optional["SparkSession"]] = None |
| _lock: ClassVar[RLock] = RLock() |
| |
| class Builder: |
| """Builder for :class:`SparkSession`.""" |
| |
| _lock = RLock() |
| |
| def __init__(self) -> None: |
| self._options: Dict[str, Any] = {} |
| self._channel_builder: Optional[DefaultChannelBuilder] = None |
| self._hook_factories: list["Callable[[SparkSession], SparkSession.Hook]"] = [] |
| |
| @overload |
| def config(self, key: str, value: Any) -> "SparkSession.Builder": |
| ... |
| |
| @overload |
| def config(self, *, map: Dict[str, "OptionalPrimitiveType"]) -> "SparkSession.Builder": |
| ... |
| |
| def config( |
| self, |
| key: Optional[str] = None, |
| value: Optional[Any] = None, |
| *, |
| map: Optional[Dict[str, "OptionalPrimitiveType"]] = None, |
| ) -> "SparkSession.Builder": |
| with self._lock: |
| if map is not None: |
| for k, v in map.items(): |
| self._options[k] = to_str(v) |
| else: |
| self._options[cast(str, key)] = to_str(value) |
| return self |
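
        # Illustrative usage of the overloads above (the keys and values below are
        # arbitrary examples, not required settings):
        #   builder.config("spark.sql.shuffle.partitions", "8")
        #   builder.config(map={"spark.sql.ansi.enabled": "true"})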
| |
| def master(self, master: str) -> "SparkSession.Builder": |
| return self |
| |
| def appName(self, name: str) -> "SparkSession.Builder": |
| return self.config("spark.app.name", name) |
| |
| def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder": |
| return self.config("spark.remote", location) |
| |
| def channelBuilder(self, channelBuilder: DefaultChannelBuilder) -> "SparkSession.Builder": |
| """Uses custom :class:`ChannelBuilder` implementation, when there is a need |
| to customize the behavior for creation of GRPC connections. |
| |
| .. versionadded:: 3.5.0 |
| |
            An example of using this class:
| |
| .. code-block:: python |
| |
| from pyspark.sql.connect import SparkSession, ChannelBuilder |
| |
| class CustomChannelBuilder(ChannelBuilder): |
| ... |
| |
| custom_channel_builder = CustomChannelBuilder(...) |
                spark = SparkSession.builder.channelBuilder(custom_channel_builder).getOrCreate()
| |
| Returns |
| ------- |
| :class:`SparkSession.Builder` |
| """ |
| with self._lock: |
                # self._channel_builder is kept in a separate field because it may hold
                # state and cannot be serialized with to_str()
| self._channel_builder = channelBuilder |
| return self |
| |
| def _registerHook( |
| self, hook_factory: "Callable[[SparkSession], SparkSession.Hook]" |
| ) -> "SparkSession.Builder": |
| with self._lock: |
| self._hook_factories.append(hook_factory) |
| return self |
| |
| def enableHiveSupport(self) -> "SparkSession.Builder": |
| raise PySparkNotImplementedError( |
| errorClass="NOT_IMPLEMENTED", messageParameters={"feature": "enableHiveSupport"} |
| ) |
| |
| def _apply_options(self, session: "SparkSession") -> None: |
| init_opts = {} |
| for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))): |
| init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]) |
| |
            # The options are applied after session creation,
            # so the options "spark.remote" and "spark.master" never take effect.
| invalid_opts = ["spark.remote", "spark.master"] |
| |
| with self._lock: |
| opts = {} |
| |
                # Only attempt to set Spark SQL configurations.
                # Setting a static configuration might throw an exception, so such
                # failures are simply ignored for now.
| for k, v in init_opts.items(): |
| if k not in invalid_opts and k.startswith("spark.sql."): |
| opts[k] = v |
| |
| for k, v in self._options.items(): |
| if k not in invalid_opts: |
| opts[k] = v |
| |
| if len(opts) > 0: |
| session.conf._set_all(configs=opts, silent=True) |
| |
| def create(self) -> "SparkSession": |
| has_channel_builder = self._channel_builder is not None |
| has_spark_remote = "spark.remote" in self._options |
| |
| if (has_channel_builder and has_spark_remote) or ( |
| not has_channel_builder and not has_spark_remote |
| ): |
| raise PySparkValueError( |
| errorClass="SESSION_NEED_CONN_STR_OR_BUILDER", messageParameters={} |
| ) |
| |
| if has_channel_builder: |
| assert self._channel_builder is not None |
| session = SparkSession( |
| connection=self._channel_builder, hook_factories=self._hook_factories |
| ) |
| else: |
| spark_remote = to_str(self._options.get("spark.remote")) |
| assert spark_remote is not None |
| session = SparkSession(connection=spark_remote, hook_factories=self._hook_factories) |
| |
| SparkSession._set_default_and_active_session(session) |
| self._apply_options(session) |
| return session |
| |
| def getOrCreate(self) -> "SparkSession": |
| with SparkSession._lock: |
| session = SparkSession.getActiveSession() |
| if session is None: |
| session = SparkSession._get_default_session() |
| if session is None: |
| session = self.create() |
| self._apply_options(session) |
| return session |
| |
| class Hook: |
| """A Hook can be used to inject behavior into the session.""" |
| |
| def on_execute_plan(self, request: "pb2.ExecutePlanRequest") -> "pb2.ExecutePlanRequest": |
| """Called before sending an ExecutePlanRequest. |
| |
| The request is replaced with the one returned by this method. |
| """ |
| return request |
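
    # A minimal, illustrative sketch of how a hook might be written and registered
    # (``LoggingHook`` is a hypothetical name, not part of the public API). A hook
    # factory receives the session and returns a Hook, so the class itself can serve
    # as the factory:
    #
    #   class LoggingHook(SparkSession.Hook):
    #       def __init__(self, session: "SparkSession") -> None:
    #           self._session = session
    #
    #       def on_execute_plan(self, request):
    #           # Inspect (or rewrite) the request before it is sent to the server.
    #           return request
    #
    #   builder = SparkSession.builder._registerHook(LoggingHook)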
| |
| _client: SparkConnectClient |
| |
| # SPARK-47544: Explicitly declaring this as an identifier instead of a method. |
| # If changing, make sure this bug is not reintroduced. |
| builder: Builder = classproperty(lambda cls: cls.Builder()) # type: ignore |
| builder.__doc__ = PySparkSession.builder.__doc__ |
| |
| def __init__( |
| self, |
| connection: Union[str, DefaultChannelBuilder], |
| userId: Optional[str] = None, |
| hook_factories: Optional[list["Callable[[SparkSession], Hook]"]] = None, |
| ) -> None: |
| """ |
| Creates a new SparkSession for the Spark Connect interface. |
| |
| Parameters |
| ---------- |
        connection : str or :class:`ChannelBuilder`
            Connection string that is used to extract the connection parameters and configure
            the GRPC connection, or an instance of :class:`ChannelBuilder` that creates the
            GRPC connection. Defaults to `sc://localhost`.
        userId : str, optional
            Optional unique user ID that is used to differentiate multiple users and
            isolate their Spark sessions. If `user_id` is not set, it defaults to
            the ``$USER`` environment variable. A user ID defined as part of the connection
            string takes precedence.
| hook_factories: list[Callable[[SparkSession], Hook]], optional |
| Optional list of hook factories for hooks that should be registered for this session. |
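
        Examples
        --------
        A session is usually obtained through the builder rather than by calling this
        constructor directly; a minimal sketch (the connection string is illustrative):

        .. code-block:: python

            spark = SparkSession.builder.remote("sc://localhost").getOrCreate()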
| """ |
| hook_factories = hook_factories or [] |
| self._client = SparkConnectClient( |
| connection=connection, |
| user_id=userId, |
| session_hooks=[factory(self) for factory in hook_factories], |
| ) |
| self._session_id = self._client._session_id |
| |
        # Set to False to prevent client.release_session on close() (testing only)
| self.release_session_on_close = True |
| |
| @classmethod |
| def _set_default_and_active_session(cls, session: "SparkSession") -> None: |
| """ |
| Set the (global) default :class:`SparkSession`, and (thread-local) |
| active :class:`SparkSession` when they are not set yet. |
| """ |
| with cls._lock: |
| if cls._default_session is None: |
| cls._default_session = session |
| if getattr(cls._active_session, "session", None) is None: |
| cls._active_session.session = session |
| |
| @classmethod |
| def _get_default_session(cls) -> Optional["SparkSession"]: |
| s = cls._default_session |
| if s is not None and not s.is_stopped: |
| return s |
| return None |
| |
| @classmethod |
| def getActiveSession(cls) -> Optional["SparkSession"]: |
| s = getattr(cls._active_session, "session", None) |
| if s is not None and not s.is_stopped: |
| return s |
| return None |
| |
| @classmethod |
| def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession": |
| """ |
| Internal use only. This method is called from the custom handler |
| generated by __reduce__. To avoid serializing a WeakRef, we create a |
| custom classmethod to instantiate the SparkSession. |
| """ |
| session = SparkSession.getActiveSession() |
| if session is None: |
| raise PySparkRuntimeError( |
| errorClass="NO_ACTIVE_SESSION", |
| messageParameters={}, |
| ) |
| if session._session_id != session_id: |
| raise PySparkAssertionError( |
| "Expected session ID does not match active session ID: " |
| f"{session_id} != {session._session_id}" |
| ) |
| return session |
| |
| getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__ |
| |
| @classmethod |
| def active(cls) -> "SparkSession": |
| session = cls.getActiveSession() |
| if session is None: |
| session = cls._get_default_session() |
| if session is None: |
| raise PySparkRuntimeError( |
| errorClass="NO_ACTIVE_OR_DEFAULT_SESSION", |
| messageParameters={}, |
| ) |
| return session |
| |
| active.__doc__ = PySparkSession.active.__doc__ |
| |
| def table(self, tableName: str) -> ParentDataFrame: |
| if not isinstance(tableName, str): |
| raise PySparkTypeError( |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "tableName", "arg_type": type(tableName).__name__}, |
| ) |
| |
| return self.read.table(tableName) |
| |
| table.__doc__ = PySparkSession.table.__doc__ |
| |
| @property |
| def read(self) -> "DataFrameReader": |
| return DataFrameReader(self) |
| |
| read.__doc__ = PySparkSession.read.__doc__ |
| |
| @property |
| def readStream(self) -> "DataStreamReader": |
| return DataStreamReader(self) |
| |
| readStream.__doc__ = PySparkSession.readStream.__doc__ |
| |
| @property |
| def tvf(self) -> "TableValuedFunction": |
| from pyspark.sql.connect.tvf import TableValuedFunction |
| |
| return TableValuedFunction(self) |
| |
| tvf.__doc__ = PySparkSession.tvf.__doc__ |
| |
| def registerProgressHandler(self, handler: "ProgressHandler") -> None: |
| self._client.register_progress_handler(handler) |
| |
| registerProgressHandler.__doc__ = PySparkSession.registerProgressHandler.__doc__ |
| |
| def removeProgressHandler(self, handler: "ProgressHandler") -> None: |
| self._client.remove_progress_handler(handler) |
| |
| removeProgressHandler.__doc__ = PySparkSession.removeProgressHandler.__doc__ |
| |
| def clearProgressHandlers(self) -> None: |
| self._client.clear_progress_handlers() |
| |
| clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__ |
| |
| def _inferSchemaFromList( |
| self, |
| data: Iterable[Any], |
| names: Optional[List[str]], |
| configs: Mapping[str, Optional[str]], |
| ) -> StructType: |
| """ |
| Infer schema from list of Row, dict, or tuple. |
| """ |
| if not data: |
| raise PySparkValueError( |
| errorClass="CANNOT_INFER_EMPTY_SCHEMA", |
| messageParameters={}, |
| ) |
| |
| ( |
| infer_dict_as_struct, |
| infer_array_from_first_element, |
| infer_map_from_first_pair, |
| prefer_timestamp, |
| ) = ( |
| configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"], |
| configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"], |
| configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"], |
| configs["spark.sql.timestampType"], |
| ) |
| return functools.reduce( |
| _merge_type, |
| ( |
| _infer_schema( |
| row, |
| names, |
| infer_dict_as_struct=(infer_dict_as_struct == "true"), |
| infer_array_from_first_element=(infer_array_from_first_element == "true"), |
| infer_map_from_first_pair=(infer_map_from_first_pair == "true"), |
| prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"), |
| ) |
| for row in data |
| ), |
| ) |
| |
| def createDataFrame( |
| self, |
| data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]], |
| schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None, |
| samplingRatio: Optional[float] = None, |
| verifySchema: Optional[bool] = None, |
| ) -> "ParentDataFrame": |
| assert data is not None |
| if isinstance(data, DataFrame): |
| raise PySparkTypeError( |
| errorClass="INVALID_TYPE", |
| messageParameters={"arg_name": "data", "arg_type": "DataFrame"}, |
| ) |
| |
| if samplingRatio is not None: |
| warnings.warn("'samplingRatio' is ignored. It is not supported with Spark Connect.") |
| |
| if verifySchema is not None: |
| warnings.warn("'verifySchema' is ignored. It is not supported with Spark Connect.") |
| |
| _schema: Optional[Union[AtomicType, StructType]] = None |
| _cols: Optional[List[str]] = None |
| _num_cols: Optional[int] = None |
| |
| if isinstance(schema, str): |
| schema = self.client._analyze( # type: ignore[assignment] |
| method="ddl_parse", ddl_string=schema |
| ).parsed |
| |
| if isinstance(schema, (AtomicType, StructType)): |
| _schema = schema |
| if isinstance(schema, StructType): |
| _num_cols = len(schema.fields) |
| else: |
| _num_cols = 1 |
| |
| elif isinstance(schema, (list, tuple)): |
| # Must re-encode any unicode strings to be consistent with StructField names |
| _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema] |
| _num_cols = len(_cols) |
| |
| elif schema is not None: |
| raise PySparkTypeError( |
| errorClass="NOT_LIST_OR_NONE_OR_STRUCT", |
| messageParameters={ |
| "arg_name": "schema", |
| "arg_type": type(schema).__name__, |
| }, |
| ) |
| |
| if isinstance(data, np.ndarray) and data.ndim not in [1, 2]: |
| raise PySparkValueError( |
| errorClass="INVALID_NDARRAY_DIMENSION", |
| messageParameters={"dimensions": "1 or 2"}, |
| ) |
| elif isinstance(data, Sized) and len(data) == 0: |
| if _schema is not None: |
| return DataFrame(LocalRelation(table=None, schema=_schema.json()), self) |
| else: |
| raise PySparkValueError( |
| errorClass="CANNOT_INFER_EMPTY_SCHEMA", |
| messageParameters={}, |
| ) |
| |
| # Get all related configs in a batch |
| configs = self._client.get_config_dict( |
| "spark.sql.timestampType", |
| "spark.sql.session.timeZone", |
| "spark.sql.session.localRelationCacheThreshold", |
| "spark.sql.execution.pandas.convertToArrowArraySafely", |
| "spark.sql.execution.pandas.inferPandasDictAsMap", |
| "spark.sql.pyspark.inferNestedDictAsStruct.enabled", |
| "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", |
| "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", |
| "spark.sql.execution.arrow.useLargeVarTypes", |
| ) |
| timezone = configs["spark.sql.session.timeZone"] |
| prefer_timestamp = configs["spark.sql.timestampType"] |
| prefers_large_types: bool = ( |
| cast(str, configs["spark.sql.execution.arrow.useLargeVarTypes"]).lower() == "true" |
| ) |
| |
| _table: Optional[pa.Table] = None |
| |
| if isinstance(data, pd.DataFrame): |
            # Logic was borrowed from `_create_from_pandas_with_arrow` in
            # `pyspark.sql.pandas.conversion`. The logic should ideally be deduplicated.
| |
| # If no schema supplied by user then get the names of columns only |
| if schema is None: |
| _cols = [str(x) if not isinstance(x, str) else x for x in data.columns] |
| infer_pandas_dict_as_map = ( |
| configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true" |
| ) |
| if infer_pandas_dict_as_map: |
| struct = StructType() |
| pa_schema = pa.Schema.from_pandas(data) |
| spark_type: Union[MapType, DataType] |
| for field in pa_schema: |
| field_type = field.type |
| if isinstance(field_type, pa.StructType): |
| if len(field_type) == 0: |
| raise PySparkValueError( |
| errorClass="CANNOT_INFER_EMPTY_SCHEMA", |
| messageParameters={}, |
| ) |
| arrow_type = field_type.field(0).type |
| spark_type = MapType(StringType(), from_arrow_type(arrow_type)) |
| else: |
| spark_type = from_arrow_type(field_type) |
| struct.add(field.name, spark_type, nullable=field.nullable) |
| schema = struct |
| elif isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns): |
| assert isinstance(_cols, list) |
| _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))]) |
| _num_cols = len(_cols) |
| |
| # Determine arrow types to coerce data when creating batches |
| arrow_schema: Optional[pa.Schema] = None |
| spark_types: List[Optional[DataType]] |
| arrow_types: List[Optional[pa.DataType]] |
| if isinstance(schema, StructType): |
| deduped_schema = cast(StructType, _deduplicate_field_names(schema)) |
| spark_types = [field.dataType for field in deduped_schema.fields] |
| arrow_schema = to_arrow_schema( |
| deduped_schema, prefers_large_types=prefers_large_types |
| ) |
| arrow_types = [field.type for field in arrow_schema] |
| _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()] |
| elif isinstance(schema, DataType): |
| raise PySparkTypeError( |
| errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW", |
| messageParameters={"data_type": str(schema)}, |
| ) |
| else: |
| # Any timestamps must be coerced to be compatible with Spark |
| spark_types = [ |
| TimestampType() |
| if is_datetime64_dtype(t) or isinstance(t, pd.DatetimeTZDtype) |
| else DayTimeIntervalType() |
| if is_timedelta64_dtype(t) |
| else None |
| for t in data.dtypes |
| ] |
| arrow_types = [ |
| to_arrow_type(dt, prefers_large_types=prefers_large_types) |
| if dt is not None |
| else None |
| for dt in spark_types |
| ] |
| |
| safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] |
| |
| ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true", False) |
| |
| _table = pa.Table.from_batches( |
| [ |
| ser._create_batch( |
| [ |
| (c, at, st) |
| for (_, c), at, st in zip(data.items(), arrow_types, spark_types) |
| ] |
| ) |
| ] |
| ) |
| |
| if isinstance(schema, StructType): |
| assert arrow_schema is not None |
| _table = _table.rename_columns( |
| cast(StructType, _deduplicate_field_names(schema)).names |
| ).cast(arrow_schema) |
| |
| elif isinstance(data, pa.Table): |
| # If no schema supplied by user then get the names of columns only |
| if schema is None: |
| _cols = data.column_names |
| if isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns): |
| assert isinstance(_cols, list) |
| _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))]) |
| _num_cols = len(_cols) |
| |
| if not isinstance(schema, StructType): |
| schema = from_arrow_schema( |
| data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ" |
| ) |
| |
| _table = ( |
| _check_arrow_table_timestamps_localize(data, schema, True, timezone) |
| .cast( |
| to_arrow_schema( |
| schema, |
| error_on_duplicated_field_names_in_struct=True, |
| prefers_large_types=prefers_large_types, |
| ) |
| ) |
| .rename_columns(schema.names) |
| ) |
| |
| elif isinstance(data, np.ndarray): |
| if _cols is None: |
| if data.ndim == 1 or data.shape[1] == 1: |
| _cols = ["value"] |
| else: |
| _cols = ["_%s" % i for i in range(1, data.shape[1] + 1)] |
| |
| if data.ndim == 1: |
| if 1 != len(_cols): |
| raise PySparkValueError( |
| errorClass="AXIS_LENGTH_MISMATCH", |
| messageParameters={ |
| "expected_length": str(len(_cols)), |
| "actual_length": "1", |
| }, |
| ) |
| |
| _table = pa.Table.from_arrays([pa.array(data)], _cols) |
| else: |
| if data.shape[1] != len(_cols): |
| raise PySparkValueError( |
| errorClass="AXIS_LENGTH_MISMATCH", |
| messageParameters={ |
| "expected_length": str(len(_cols)), |
| "actual_length": str(data.shape[1]), |
| }, |
| ) |
| |
| _table = pa.Table.from_arrays( |
| [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols |
| ) |
| |
| # The _table should already have the proper column names. |
| _cols = None |
| |
| else: |
| _data = list(data) |
| |
| if isinstance(_data[0], dict): |
                # Sort the data to match the inferred schema.
                # For dictionaries, the schema fields are sorted in alphabetical order,
                # so sort each row's keys accordingly.
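                # (e.g. an input row {"b": 1, "a": 2} is reordered to {"a": 2, "b": 1})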
| _data = [dict(sorted(d.items())) if d is not None else None for d in _data] |
| |
| elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr( |
| _data[0], "__dict__" |
| ): |
| # input data can be [1, 2, 3] |
| # we need to convert it to [[1], [2], [3]] to be able to infer schema. |
| _data = [[d] for d in _data] |
| |
| if _schema is not None: |
| if not isinstance(_schema, StructType): |
| _schema = StructType().add("value", _schema) |
| else: |
| _schema = self._inferSchemaFromList(_data, _cols, configs) |
| |
| if _cols is not None and cast(int, _num_cols) < len(_cols): |
| _num_cols = len(_cols) |
| |
| if _has_nulltype(_schema): |
| # For cases like createDataFrame([("Alice", None, 80.1)], schema) |
| # we can not infer the schema from the data itself. |
| raise PySparkValueError( |
| errorClass="CANNOT_DETERMINE_TYPE", messageParameters={} |
| ) |
| |
| from pyspark.sql.conversion import ( |
| LocalDataToArrowConversion, |
| ) |
| |
            # Spark Connect tries its best to build the Arrow table with the
            # inferred schema on the client side; the columns are then renamed and
            # the datatypes cast on the server side.
| _table = LocalDataToArrowConversion.convert(_data, _schema, prefers_large_types) |
| |
| # TODO: Beside the validation on number of columns, we should also check |
| # whether the Arrow Schema is compatible with the user provided Schema. |
| if _num_cols is not None and _num_cols != _table.shape[1]: |
| raise PySparkValueError( |
| errorClass="AXIS_LENGTH_MISMATCH", |
| messageParameters={ |
| "expected_length": str(_num_cols), |
| "actual_length": str(_table.shape[1]), |
| }, |
| ) |
| |
| if _schema is not None: |
| local_relation = LocalRelation(_table, schema=_schema.json()) |
| else: |
| local_relation = LocalRelation(_table) |
| |
| cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"] |
| plan: LogicalPlan = local_relation |
| if cache_threshold is not None and int(cache_threshold) <= _table.nbytes: |
| plan = CachedLocalRelation(self._cache_local_relation(local_relation)) |
| |
| df = DataFrame(plan, self) |
| if _cols is not None and len(_cols) > 0: |
| df = df.toDF(*_cols) # type: ignore[assignment] |
| return df |
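
    # Illustrative inputs accepted by createDataFrame above (the values are arbitrary
    # examples):
    #   spark.createDataFrame([(1, "a"), (2, "b")], schema=["id", "name"])
    #   spark.createDataFrame(pd.DataFrame({"id": [1, 2]}))
    #   spark.createDataFrame(pa.table({"id": [1, 2]}))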
| |
| createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__ |
| |
| def sql( |
| self, |
| sqlQuery: str, |
| args: Optional[Union[Dict[str, Any], List]] = None, |
| **kwargs: Any, |
| ) -> "ParentDataFrame": |
| _args = [] |
| _named_args = {} |
| if args is not None: |
| if isinstance(args, Dict): |
| for k, v in args.items(): |
| assert isinstance(k, str) |
| _named_args[k] = F.lit(v) |
| elif isinstance(args, List): |
| _args = [F.lit(v) for v in args] |
| else: |
| raise PySparkTypeError( |
| errorClass="INVALID_TYPE", |
| messageParameters={"arg_name": "args", "arg_type": type(args).__name__}, |
| ) |
| |
| _views: List[SubqueryAlias] = [] |
| if len(kwargs) > 0: |
| from pyspark.sql.connect.sql_formatter import SQLStringFormatter |
| |
| formatter = SQLStringFormatter(self) |
| sqlQuery = formatter.format(sqlQuery, **kwargs) |
| |
| for df, name in formatter._temp_views: |
| _views.append(SubqueryAlias(df._plan, name)) |
| |
| cmd = SQL(sqlQuery, _args, _named_args, _views) |
| data, properties, ei = self.client.execute_command(cmd.command(self._client)) |
| if "sql_command_result" in properties: |
| df = DataFrame(CachedRelation(properties["sql_command_result"]), self) |
            # A command result contains the execution info.
| df._execution_info = ei |
| return df |
| else: |
| return DataFrame(cmd, self) |
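
    # Illustrative parameterized queries (the values and names below are arbitrary
    # examples):
    #   spark.sql("SELECT * FROM range(10) WHERE id > :min_id", args={"min_id": 5})
    #   spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[5])
    #   spark.sql("SELECT * FROM {df}", df=some_df)  # kwargs are exposed as temp views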
| |
| sql.__doc__ = PySparkSession.sql.__doc__ |
| |
| def range( |
| self, |
| start: int, |
| end: Optional[int] = None, |
| step: int = 1, |
| numPartitions: Optional[int] = None, |
| ) -> ParentDataFrame: |
| if end is None: |
| actual_end = start |
| start = 0 |
| else: |
| actual_end = end |
| |
| if numPartitions is not None: |
| numPartitions = int(numPartitions) |
| |
| return DataFrame( |
| Range( |
| start=int(start), end=int(actual_end), step=int(step), num_partitions=numPartitions |
| ), |
| self, |
| ) |
| |
| range.__doc__ = PySparkSession.range.__doc__ |
| |
| @functools.cached_property |
| def catalog(self) -> "Catalog": |
| from pyspark.sql.connect.catalog import Catalog |
| |
| return Catalog(self) |
| |
| catalog.__doc__ = PySparkSession.catalog.__doc__ |
| |
| def __del__(self) -> None: |
| try: |
            # StreamingQueryManager has client state that needs to be cleaned up
| if hasattr(self, "_sqm"): |
| self._sqm.close() |
| # Try its best to close. |
| self.client.close() |
| except Exception: |
| pass |
| |
| def interruptAll(self) -> List[str]: |
| op_ids = self.client.interrupt_all() |
| assert op_ids is not None |
| return op_ids |
| |
| interruptAll.__doc__ = PySparkSession.interruptAll.__doc__ |
| |
| def interruptTag(self, tag: str) -> List[str]: |
| op_ids = self.client.interrupt_tag(tag) |
| assert op_ids is not None |
| return op_ids |
| |
| interruptTag.__doc__ = PySparkSession.interruptTag.__doc__ |
| |
| def interruptOperation(self, op_id: str) -> List[str]: |
| op_ids = self.client.interrupt_operation(op_id) |
| assert op_ids is not None |
| return op_ids |
| |
| interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__ |
| |
| def addTag(self, tag: str) -> None: |
| self.client.add_tag(tag) |
| |
| addTag.__doc__ = PySparkSession.addTag.__doc__ |
| |
| def removeTag(self, tag: str) -> None: |
| self.client.remove_tag(tag) |
| |
| removeTag.__doc__ = PySparkSession.removeTag.__doc__ |
| |
| def getTags(self) -> Set[str]: |
| return self.client.get_tags() |
| |
| getTags.__doc__ = PySparkSession.getTags.__doc__ |
| |
| def clearTags(self) -> None: |
| return self.client.clear_tags() |
| |
| clearTags.__doc__ = PySparkSession.clearTags.__doc__ |
| |
| def stop(self) -> None: |
| """ |
| Release the current session and close the GRPC connection to the Spark Connect server. |
        Reset the active session so that calls to getOrCreate() create a new session.
| If the session was created in local mode, the Spark Connect server running locally is also |
| terminated. |
| |
| This API is best-effort and idempotent, i.e., if any of the operations fail, the API will |
| not produce an error. Stopping an already stopped session is a no-op. |
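
        A minimal usage sketch:

        .. code-block:: python

            spark.stop()
            # A subsequent getOrCreate() returns a fresh session.
            spark = SparkSession.builder.remote("sc://localhost").getOrCreate()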
| """ |
| |
        # Whereas the regular PySpark session immediately terminates the SparkContext
        # itself (thereby stopping all Spark sessions), this only stops this one session
        # on the server.
        # Following the regular Spark session's behavior would be controversial here:
        # in Spark Connect, the server is designed for multi-tenancy, so a remote client
        # cannot simply stop the server and thereby stop the remote clients of other
        # users.
| with SparkSession._lock: |
| if not self.is_stopped and self.release_session_on_close: |
| try: |
| self.client.release_session() |
| except Exception as e: |
                    logger.warn(f"session.stop(): Session could not be released. Error: {e}")
| |
| try: |
| self.client.close() |
| except Exception as e: |
                logger.warn(f"session.stop(): Client could not be closed. Error: {e}")
| |
| if self is SparkSession._default_session: |
| SparkSession._default_session = None |
| if self is getattr(SparkSession._active_session, "session", None): |
| SparkSession._active_session.session = None |
| |
| if "SPARK_LOCAL_REMOTE" in os.environ: |
| # When local mode is in use, follow the regular Spark session's |
| # behavior by terminating the Spark Connect server, |
| # meaning that you can stop local mode, and restart the Spark Connect |
| # client with a different remote address. |
| if PySparkSession._activeSession is not None: |
| try: |
| PySparkSession._activeSession.stop() |
| except Exception as e: |
                        logger.warn(
                            "session.stop(): Local Spark Connect Server could not be stopped. "
                            f"Error: {e}"
                        )
| del os.environ["SPARK_LOCAL_REMOTE"] |
| del os.environ["SPARK_CONNECT_MODE_ENABLED"] |
| if "SPARK_REMOTE" in os.environ: |
| del os.environ["SPARK_REMOTE"] |
| |
| @property |
| def is_stopped(self) -> bool: |
| """ |
        Returns whether this session was stopped.
| """ |
| return self.client.is_closed |
| |
| @property |
| def conf(self) -> RuntimeConf: |
| return RuntimeConf(self.client) |
| |
| conf.__doc__ = PySparkSession.conf.__doc__ |
| |
| @property |
| def streams(self) -> "StreamingQueryManager": |
| if hasattr(self, "_sqm"): |
| return self._sqm |
| self._sqm: StreamingQueryManager = StreamingQueryManager(self) |
| return self._sqm |
| |
| streams.__doc__ = PySparkSession.streams.__doc__ |
| |
| def __getattr__(self, name: str) -> Any: |
| if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession", "sparkContext", "newSession"]: |
| raise PySparkAttributeError( |
| errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": name} |
| ) |
| return object.__getattribute__(self, name) |
| |
| @property |
| def udf(self) -> "UDFRegistration": |
| from pyspark.sql.connect.udf import UDFRegistration |
| |
| return UDFRegistration(self) |
| |
| udf.__doc__ = PySparkSession.udf.__doc__ |
| |
| @property |
| def udtf(self) -> "UDTFRegistration": |
| from pyspark.sql.connect.udtf import UDTFRegistration |
| |
| return UDTFRegistration(self) |
| |
| udtf.__doc__ = PySparkSession.udtf.__doc__ |
| |
| @property |
| def dataSource(self) -> "DataSourceRegistration": |
| from pyspark.sql.connect.datasource import DataSourceRegistration |
| |
| return DataSourceRegistration(self) |
| |
| dataSource.__doc__ = PySparkSession.dataSource.__doc__ |
| |
| @functools.cached_property |
| def version(self) -> str: |
| result = self._client._analyze(method="spark_version").spark_version |
| assert result is not None |
| return result |
| |
| version.__doc__ = PySparkSession.version.__doc__ |
| |
| @property |
| def client(self) -> "SparkConnectClient": |
| return self._client |
| |
| client.__doc__ = PySparkSession.client.__doc__ |
| |
| def addArtifacts( |
| self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False |
| ) -> None: |
| if sum([file, pyfile, archive]) > 1: |
| raise PySparkValueError( |
| errorClass="INVALID_MULTIPLE_ARGUMENT_CONDITIONS", |
| messageParameters={ |
| "arg_names": "'pyfile', 'archive' and/or 'file'", |
| "condition": "True together", |
| }, |
| ) |
| self._client.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file) |
| |
| addArtifacts.__doc__ = PySparkSession.addArtifacts.__doc__ |
| |
| addArtifact = addArtifacts |
| |
| def _cache_local_relation(self, local_relation: LocalRelation) -> str: |
| """ |
        Cache the local relation on the server side if it has not been cached yet.
| """ |
| serialized = local_relation.serialize(self._client) |
| return self._client.cache_artifact(serialized) |
| |
| def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None: |
| if urllib.parse.urlparse(dest_path).scheme: |
| raise PySparkValueError( |
| errorClass="NO_SCHEMA_AND_DRIVER_DEFAULT_SCHEME", |
| messageParameters={"arg_name": "dest_path"}, |
| ) |
| self._client.copy_from_local_to_fs(local_path, dest_path) |
| |
| copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__ |
| |
| def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame": |
| """ |
        An internal API to reference a runtime DataFrame on the server side.
        This is used in the foreachBatch() runner, where the remote DataFrame refers to the
        output of a micro-batch.
| """ |
| return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self) |
| |
| @staticmethod |
| def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: |
| """ |
| Starts the Spark Connect server given the master (thread-unsafe). |
| |
        1. Temporarily removes all state for Spark Connect, for example, the
        ``SPARK_REMOTE`` environment variable.

        2. Starts a regular Spark session that automatically starts a Spark Connect
        server via the ``spark.plugins`` feature.

        The authentication token to be used when connecting to this server is exposed
        through the ``SPARK_CONNECT_AUTHENTICATE_TOKEN`` environment variable.
| """ |
| from pyspark import SparkContext, SparkConf |
| |
| session = PySparkSession._instantiatedSession |
| if session is None or session._sc._jsc is None: |
| init_opts = {} |
| for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))): |
| init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]) |
| init_opts.update(opts) |
| opts = init_opts |
| |
| # Configurations to be overwritten |
| overwrite_conf = opts |
| overwrite_conf["spark.master"] = master |
| if "spark.remote" in overwrite_conf: |
| del overwrite_conf["spark.remote"] |
| if "spark.api.mode" in overwrite_conf: |
| del overwrite_conf["spark.api.mode"] |
| |
            # Check for a user-provided authentication token, creating a new one if none
            # is given, and make sure it is set in the environment.
| if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ: |
| os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get( |
| "spark.connect.authenticate.token", str(uuid.uuid4()) |
| ) |
| |
| # Configurations to be set if unset. |
| default_conf = { |
| "spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin", |
| "spark.sql.artifact.isolation.enabled": "true", |
| "spark.sql.artifact.isolation.alwaysApplyClassloader": "true", |
| } |
| |
| if "SPARK_TESTING" in os.environ: |
| # For testing, we use 0 to use an ephemeral port to allow parallel testing. |
| # See also SPARK-42272. |
| overwrite_conf["spark.connect.grpc.binding.port"] = "0" |
| |
| origin_remote = os.environ.get("SPARK_REMOTE", None) |
| try: |
                # Make SparkSubmit think no remote is set so that it starts
                # the regular PySpark session.
| if origin_remote is not None: |
| del os.environ["SPARK_REMOTE"] |
| |
                # The regular PySpark session is registered as an active session
                # so it will not be garbage-collected.
| conf = SparkConf(loadDefaults=True) |
| conf.setAll(list(overwrite_conf.items())).setAll(list(default_conf.items())) |
| PySparkSession(SparkContext.getOrCreate(conf)) |
| |
                # Lastly, keep only runtime configurations because other configurations
                # cannot be set in the regular Spark Connect session.
| utl = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr] |
| runtime_conf_keys = [c._1() for c in utl.listRuntimeSQLConfigs()] |
| new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys} |
| opts.clear() |
| opts.update(new_opts) |
| finally: |
| if origin_remote is not None: |
| os.environ["SPARK_REMOTE"] = origin_remote |
| else: |
| raise PySparkRuntimeError( |
| errorClass="SESSION_OR_CONTEXT_EXISTS", |
| messageParameters={}, |
| ) |
| |
| @property |
| def session_id(self) -> str: |
| return self._session_id |
| |
| @property |
| def _profiler_collector(self) -> ProfilerCollector: |
| return self._client._profiler_collector |
| |
| @property |
| def profile(self) -> Profile: |
| return Profile(self._client._profiler_collector) |
| |
| profile.__doc__ = PySparkSession.profile.__doc__ |
| |
| def __reduce__(self) -> Tuple: |
| """ |
        This method is called when the object is pickled. It returns a tuple of the object's
        constructor function, the arguments to it, and the local state of the object.
        It should only be used when the active Spark session that is unpickled
        is the same active Spark session that was pickled.
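
        A conceptual sketch of the round trip (illustrative only):

        .. code-block:: python

            import pickle

            restored = pickle.loads(pickle.dumps(spark))
            # `restored` is the currently active session, provided the session IDs match.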
| """ |
| |
| def creator(old_session_id: str) -> "SparkSession": |
| # We cannot perform the checks for session matching here because accessing the |
| # session ID property causes the serialization of a WeakRef and in turn breaks |
| # the serialization. |
| return SparkSession._getActiveSessionIfMatches(old_session_id) |
| |
| return creator, (self._session_id,) |
| |
| def _to_ddl(self, struct: StructType) -> str: |
| ddl = self._client._analyze(method="json_to_ddl", json_string=struct.json()).ddl_string |
| assert ddl is not None |
| return ddl |
| |
| def _parse_ddl(self, ddl: str) -> DataType: |
| dt = self._client._analyze(method="ddl_parse", ddl_string=ddl).parsed |
| assert dt is not None |
| return dt |
| |
| |
| SparkSession.__doc__ = PySparkSession.__doc__ |
| |
| |
| def _test() -> None: |
| import os |
| import sys |
| import doctest |
| from pyspark.sql import SparkSession as PySparkSession |
| import pyspark.sql.connect.session |
| |
| globs = pyspark.sql.connect.session.__dict__.copy() |
| globs["spark"] = ( |
| PySparkSession.builder.appName("sql.connect.session tests") |
| .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) |
| .getOrCreate() |
| ) |
| |
| # Uses PySpark session to test builder. |
| globs["SparkSession"] = PySparkSession |
    # Spark Connect does not support setting the master.
| pyspark.sql.connect.session.SparkSession.__doc__ = None |
| del pyspark.sql.connect.session.SparkSession.Builder.master.__doc__ |
| |
| (failure_count, test_count) = doctest.testmod( |
| pyspark.sql.connect.session, |
| globs=globs, |
| optionflags=doctest.ELLIPSIS |
| | doctest.NORMALIZE_WHITESPACE |
| | doctest.IGNORE_EXCEPTION_DETAIL, |
| ) |
| |
| globs["spark"].stop() |
| |
| if failure_count: |
| sys.exit(-1) |
| |
| |
| if __name__ == "__main__": |
| _test() |