#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
import threading
import os
import warnings
from collections.abc import Sized
from functools import reduce
from threading import RLock
from typing import (
Optional,
Any,
Union,
Dict,
List,
Tuple,
Set,
cast,
overload,
Iterable,
TYPE_CHECKING,
ClassVar,
)
import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.api.types import ( # type: ignore[attr-defined]
is_datetime64_dtype,
is_timedelta64_dtype,
)
import urllib
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.loose_version import LooseVersion
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.plan import (
SQL,
Range,
LocalRelation,
LogicalPlan,
CachedLocalRelation,
CachedRelation,
CachedRemoteRelation,
SubqueryAlias,
)
from pyspark.sql.connect.functions import builtin as F
from pyspark.sql.connect.profiler import ProfilerCollector
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming.readwriter import DataStreamReader
from pyspark.sql.connect.streaming.query import StreamingQueryManager
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
from pyspark.sql.pandas.types import (
to_arrow_schema,
to_arrow_type,
_deduplicate_field_names,
from_arrow_type,
)
from pyspark.sql.profiler import Profile
from pyspark.sql.session import classproperty, SparkSession as PySparkSession
from pyspark.sql.types import (
_infer_schema,
_has_nulltype,
_merge_type,
Row,
DataType,
DayTimeIntervalType,
StructType,
AtomicType,
TimestampType,
MapType,
StringType,
)
from pyspark.sql.utils import to_str
from pyspark.errors import (
PySparkAttributeError,
PySparkNotImplementedError,
PySparkRuntimeError,
PySparkValueError,
PySparkTypeError,
PySparkAssertionError,
)
if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.catalog import Catalog
from pyspark.sql.connect.udf import UDFRegistration
from pyspark.sql.connect.udtf import UDTFRegistration
from pyspark.sql.connect.shell.progress import ProgressHandler
from pyspark.sql.connect.datasource import DataSourceRegistration
try:
import memory_profiler # noqa: F401
has_memory_profiler = True
except Exception:
has_memory_profiler = False
class SparkSession:
# The active SparkSession for the current thread
_active_session: ClassVar[threading.local] = threading.local()
# Reference to the root SparkSession
_default_session: ClassVar[Optional["SparkSession"]] = None
_lock: ClassVar[RLock] = RLock()
class Builder:
"""Builder for :class:`SparkSession`."""
_lock = RLock()
def __init__(self) -> None:
self._options: Dict[str, Any] = {}
self._channel_builder: Optional[DefaultChannelBuilder] = None
@overload
def config(self, key: str, value: Any) -> "SparkSession.Builder":
...
@overload
def config(self, *, map: Dict[str, "OptionalPrimitiveType"]) -> "SparkSession.Builder":
...
def config(
self,
key: Optional[str] = None,
value: Optional[Any] = None,
*,
map: Optional[Dict[str, "OptionalPrimitiveType"]] = None,
) -> "SparkSession.Builder":
with self._lock:
if map is not None:
for k, v in map.items():
self._options[k] = to_str(v)
else:
self._options[cast(str, key)] = to_str(value)
return self
def master(self, master: str) -> "SparkSession.Builder":
return self
def appName(self, name: str) -> "SparkSession.Builder":
return self.config("spark.app.name", name)
def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder":
return self.config("spark.remote", location)
def channelBuilder(self, channelBuilder: DefaultChannelBuilder) -> "SparkSession.Builder":
"""Uses custom :class:`ChannelBuilder` implementation, when there is a need
to customize the behavior for creation of GRPC connections.
.. versionadded:: 3.5.0
An example to use this class looks like this:
.. code-block:: python
from pyspark.sql.connect import SparkSession, ChannelBuilder
class CustomChannelBuilder(ChannelBuilder):
...
custom_channel_builder = CustomChannelBuilder(...)
spark = SparkSession.builder().channelBuilder(custom_channel_builder).getOrCreate()
Returns
-------
:class:`SparkSession.Builder`
"""
with self._lock:
                # self._channel_builder is kept as a separate field because it may hold
                # state and cannot be serialized with to_str()
self._channel_builder = channelBuilder
return self
def enableHiveSupport(self) -> "SparkSession.Builder":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED", message_parameters={"feature": "enableHiveSupport"}
)
def _apply_options(self, session: "SparkSession") -> None:
with self._lock:
for k, v in self._options.items():
                    # the options are applied after session creation,
                    # so the following options never take effect
if k not in [
"spark.remote",
"spark.master",
]:
try:
session.conf.set(k, v)
except Exception as e:
warnings.warn(str(e))
def create(self) -> "SparkSession":
has_channel_builder = self._channel_builder is not None
has_spark_remote = "spark.remote" in self._options
if (has_channel_builder and has_spark_remote) or (
not has_channel_builder and not has_spark_remote
):
raise PySparkValueError(
error_class="SESSION_NEED_CONN_STR_OR_BUILDER", message_parameters={}
)
if has_channel_builder:
assert self._channel_builder is not None
session = SparkSession(connection=self._channel_builder)
else:
spark_remote = to_str(self._options.get("spark.remote"))
assert spark_remote is not None
session = SparkSession(connection=spark_remote)
SparkSession._set_default_and_active_session(session)
self._apply_options(session)
return session
def getOrCreate(self) -> "SparkSession":
with SparkSession._lock:
session = SparkSession.getActiveSession()
if session is None or session.is_stopped:
session = SparkSession._default_session
if session is None or session.is_stopped:
session = self.create()
self._apply_options(session)
return session
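        # A minimal usage sketch for this builder (the connection string, port, and the
        # config option below are illustrative assumptions, not requirements):
        #
        #   spark = (
        #       SparkSession.builder
        #       .remote("sc://localhost:15002")
        #       .config("spark.sql.shuffle.partitions", "8")
        #       .getOrCreate()
        #   )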
_client: SparkConnectClient
# SPARK-47544: Explicitly declaring this as an identifier instead of a method.
# If changing, make sure this bug is not reintroduced.
builder: Builder = classproperty(lambda cls: cls.Builder()) # type: ignore
builder.__doc__ = PySparkSession.builder.__doc__
def __init__(self, connection: Union[str, DefaultChannelBuilder], userId: Optional[str] = None):
"""
Creates a new SparkSession for the Spark Connect interface.
Parameters
----------
connection: str or class:`ChannelBuilder`
Connection string that is used to extract the connection parameters and configure
the GRPC connection. Or instance of ChannelBuilder that creates GRPC connection.
Defaults to `sc://localhost`.
userId : str, optional
Optional unique user ID that is used to differentiate multiple users and
isolate their Spark Sessions. If the `user_id` is not set, will default to
the $USER environment. Defining the user ID as part of the connection string
takes precedence.
"""
self._client = SparkConnectClient(connection=connection, user_id=userId)
self._session_id = self._client._session_id
# Set to false to prevent client.release_session on close() (testing only)
self.release_session_on_close = True
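        # A hedged example of constructing a session directly (normally the builder is
        # preferred; the address and the `user_id` connection parameter are illustrative):
        #
        #   session = SparkSession(connection="sc://localhost:15002/;user_id=example_user")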
@classmethod
def _set_default_and_active_session(cls, session: "SparkSession") -> None:
"""
Set the (global) default :class:`SparkSession`, and (thread-local)
active :class:`SparkSession` when they are not set yet.
"""
with cls._lock:
if cls._default_session is None:
cls._default_session = session
if getattr(cls._active_session, "session", None) is None:
cls._active_session.session = session
@classmethod
def getActiveSession(cls) -> Optional["SparkSession"]:
return getattr(cls._active_session, "session", None)
@classmethod
def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
"""
Internal use only. This method is called from the custom handler
generated by __reduce__. To avoid serializing a WeakRef, we create a
custom classmethod to instantiate the SparkSession.
"""
session = SparkSession.getActiveSession()
if session is None:
raise PySparkRuntimeError(
error_class="NO_ACTIVE_SESSION",
message_parameters={},
)
if session._session_id != session_id:
raise PySparkAssertionError(
"Expected session ID does not match active session ID: "
f"{session_id} != {session._session_id}"
)
return session
getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__
@classmethod
def active(cls) -> "SparkSession":
session = cls.getActiveSession()
if session is None:
session = cls._default_session
if session is None:
raise PySparkRuntimeError(
error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
message_parameters={},
)
return session
active.__doc__ = PySparkSession.active.__doc__
def table(self, tableName: str) -> ParentDataFrame:
if not isinstance(tableName, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "tableName", "arg_type": type(tableName).__name__},
)
return self.read.table(tableName)
table.__doc__ = PySparkSession.table.__doc__
@property
def read(self) -> "DataFrameReader":
return DataFrameReader(self)
read.__doc__ = PySparkSession.read.__doc__
@property
def readStream(self) -> "DataStreamReader":
return DataStreamReader(self)
readStream.__doc__ = PySparkSession.readStream.__doc__
def registerProgressHandler(self, handler: "ProgressHandler") -> None:
self._client.register_progress_handler(handler)
registerProgressHandler.__doc__ = PySparkSession.registerProgressHandler.__doc__
def removeProgressHandler(self, handler: "ProgressHandler") -> None:
self._client.remove_progress_handler(handler)
removeProgressHandler.__doc__ = PySparkSession.removeProgressHandler.__doc__
def clearProgressHandlers(self) -> None:
self._client.clear_progress_handlers()
clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__
def _inferSchemaFromList(
self, data: Iterable[Any], names: Optional[List[str]] = None
) -> StructType:
"""
Infer schema from list of Row, dict, or tuple.
"""
if not data:
raise PySparkValueError(
error_class="CANNOT_INFER_EMPTY_SCHEMA",
message_parameters={},
)
(
infer_dict_as_struct,
infer_array_from_first_element,
prefer_timestamp_ntz,
) = self._client.get_configs(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.timestampType",
)
return reduce(
_merge_type,
(
_infer_schema(
row,
names,
infer_dict_as_struct=(infer_dict_as_struct == "true"),
infer_array_from_first_element=(infer_array_from_first_element == "true"),
prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
)
for row in data
),
)
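    # A rough sketch of what the inference above yields (input is illustrative, not
    # taken from tests):
    #
    #   spark._inferSchemaFromList([{"age": 1, "name": "a"}, {"age": 2, "name": "b"}])
    #   # -> StructType([StructField("age", LongType()), StructField("name", StringType())])
    #
    # The three configs read above control nested-dict, array-element, and timestamp
    # inference respectively.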
def createDataFrame(
self,
data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: Optional[bool] = None,
) -> "ParentDataFrame":
assert data is not None
if isinstance(data, DataFrame):
raise PySparkTypeError(
error_class="INVALID_TYPE",
message_parameters={"arg_name": "data", "arg_type": "DataFrame"},
)
if samplingRatio is not None:
warnings.warn("'samplingRatio' is ignored. It is not supported with Spark Connect.")
if verifySchema is not None:
warnings.warn("'verifySchema' is ignored. It is not supported with Spark Connect.")
_schema: Optional[Union[AtomicType, StructType]] = None
_cols: Optional[List[str]] = None
_num_cols: Optional[int] = None
if isinstance(schema, str):
schema = self.client._analyze( # type: ignore[assignment]
method="ddl_parse", ddl_string=schema
).parsed
if isinstance(schema, (AtomicType, StructType)):
_schema = schema
if isinstance(schema, StructType):
_num_cols = len(schema.fields)
else:
_num_cols = 1
elif isinstance(schema, (list, tuple)):
# Must re-encode any unicode strings to be consistent with StructField names
_cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
_num_cols = len(_cols)
elif schema is not None:
raise PySparkTypeError(
error_class="NOT_LIST_OR_NONE_OR_STRUCT",
message_parameters={
"arg_name": "schema",
"arg_type": type(schema).__name__,
},
)
if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
raise PySparkValueError(
error_class="INVALID_NDARRAY_DIMENSION",
message_parameters={"dimensions": "1 or 2"},
)
elif isinstance(data, Sized) and len(data) == 0:
if _schema is not None:
return DataFrame(LocalRelation(table=None, schema=_schema.json()), self)
else:
raise PySparkValueError(
error_class="CANNOT_INFER_EMPTY_SCHEMA",
message_parameters={},
)
_table: Optional[pa.Table] = None
if isinstance(data, pd.DataFrame):
            # Logic was borrowed from `_create_from_pandas_with_arrow` in
            # `pyspark.sql.pandas.conversion`. Ideally the two code paths should be deduplicated.
            # If no schema is supplied by the user, get only the column names.
if schema is None:
_cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
infer_pandas_dict_as_map = (
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
== "true"
)
if infer_pandas_dict_as_map:
struct = StructType()
pa_schema = pa.Schema.from_pandas(data)
spark_type: Union[MapType, DataType]
for field in pa_schema:
field_type = field.type
if isinstance(field_type, pa.StructType):
if len(field_type) == 0:
raise PySparkValueError(
error_class="CANNOT_INFER_EMPTY_SCHEMA",
message_parameters={},
)
arrow_type = field_type.field(0).type
spark_type = MapType(StringType(), from_arrow_type(arrow_type))
else:
spark_type = from_arrow_type(field_type)
struct.add(field.name, spark_type, nullable=field.nullable)
schema = struct
elif isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns):
assert isinstance(_cols, list)
_cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
_num_cols = len(_cols)
# Determine arrow types to coerce data when creating batches
arrow_schema: Optional[pa.Schema] = None
spark_types: List[Optional[DataType]]
arrow_types: List[Optional[pa.DataType]]
if isinstance(schema, StructType):
deduped_schema = cast(StructType, _deduplicate_field_names(schema))
spark_types = [field.dataType for field in deduped_schema.fields]
arrow_schema = to_arrow_schema(deduped_schema)
arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
elif isinstance(schema, DataType):
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
message_parameters={"data_type": str(schema)},
)
else:
# Any timestamps must be coerced to be compatible with Spark
spark_types = [
TimestampType()
if is_datetime64_dtype(t) or isinstance(t, pd.DatetimeTZDtype)
else DayTimeIntervalType()
if is_timedelta64_dtype(t)
else None
for t in data.dtypes
]
arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
timezone, safecheck = self._client.get_configs(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
)
ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
_table = pa.Table.from_batches(
[
ser._create_batch(
[
(c, at, st)
for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
]
)
]
)
if isinstance(schema, StructType):
assert arrow_schema is not None
_table = _table.rename_columns(
cast(StructType, _deduplicate_field_names(schema)).names
).cast(arrow_schema)
elif isinstance(data, np.ndarray):
if _cols is None:
if data.ndim == 1 or data.shape[1] == 1:
_cols = ["value"]
else:
_cols = ["_%s" % i for i in range(1, data.shape[1] + 1)]
if data.ndim == 1:
if 1 != len(_cols):
raise PySparkValueError(
error_class="AXIS_LENGTH_MISMATCH",
message_parameters={
"expected_length": str(len(_cols)),
"actual_length": "1",
},
)
_table = pa.Table.from_arrays([pa.array(data)], _cols)
else:
if data.shape[1] != len(_cols):
raise PySparkValueError(
error_class="AXIS_LENGTH_MISMATCH",
message_parameters={
"expected_length": str(len(_cols)),
"actual_length": str(data.shape[1]),
},
)
_table = pa.Table.from_arrays(
[pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols
)
# The _table should already have the proper column names.
_cols = None
else:
_data = list(data)
if isinstance(_data[0], dict):
                # Sort the data to respect the inferred schema.
                # For dictionaries, the fields are sorted in alphabetical order.
_data = [dict(sorted(d.items())) if d is not None else None for d in _data]
elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
_data[0], "__dict__"
):
# input data can be [1, 2, 3]
# we need to convert it to [[1], [2], [3]] to be able to infer schema.
_data = [[d] for d in _data]
if _schema is not None:
if not isinstance(_schema, StructType):
_schema = StructType().add("value", _schema)
else:
_schema = self._inferSchemaFromList(_data, _cols)
if _cols is not None and cast(int, _num_cols) < len(_cols):
_num_cols = len(_cols)
if _has_nulltype(_schema):
                    # For cases like createDataFrame([("Alice", None, 80.1)], schema)
                    # we cannot infer the schema from the data itself.
raise PySparkValueError(
error_class="CANNOT_DETERMINE_TYPE", message_parameters={}
)
from pyspark.sql.connect.conversion import LocalDataToArrowConversion
            # Spark Connect will try its best to build the Arrow table with the
            # inferred schema on the client side, and then rename the columns and
            # cast the datatypes on the server side.
_table = LocalDataToArrowConversion.convert(_data, _schema)
        # TODO: Besides validating the number of columns, we should also check
        # whether the Arrow schema is compatible with the user-provided schema.
if _num_cols is not None and _num_cols != _table.shape[1]:
raise PySparkValueError(
error_class="AXIS_LENGTH_MISMATCH",
message_parameters={
"expected_length": str(_num_cols),
"actual_length": str(_table.shape[1]),
},
)
if _schema is not None:
local_relation = LocalRelation(_table, schema=_schema.json())
else:
local_relation = LocalRelation(_table)
cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
plan: LogicalPlan = local_relation
if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
plan = CachedLocalRelation(self._cache_local_relation(local_relation))
df = DataFrame(plan, self)
if _cols is not None and len(_cols) > 0:
df = df.toDF(*_cols) # type: ignore[assignment]
return df
createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
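    # A hedged sketch of the inputs this method accepts (data and column names are
    # illustrative only):
    #
    #   spark.createDataFrame([(1, "a"), (2, "b")], schema=["id", "value"])
    #   spark.createDataFrame([{"id": 1, "value": "a"}])
    #   spark.createDataFrame(pd.DataFrame({"id": [1, 2], "value": ["a", "b"]}))
    #   spark.createDataFrame(np.array([[1, 2], [3, 4]]), schema=["c1", "c2"])
    #
    # In every case the data is converted into an Arrow table on the client and shipped
    # to the server as a LocalRelation (or a CachedLocalRelation above the size threshold).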
def sql(
self,
sqlQuery: str,
args: Optional[Union[Dict[str, Any], List]] = None,
**kwargs: Any,
) -> "ParentDataFrame":
_args = []
_named_args = {}
if args is not None:
if isinstance(args, Dict):
for k, v in args.items():
assert isinstance(k, str)
_named_args[k] = F.lit(v)
elif isinstance(args, List):
_args = [F.lit(v) for v in args]
else:
raise PySparkTypeError(
error_class="INVALID_TYPE",
message_parameters={"arg_name": "args", "arg_type": type(args).__name__},
)
_views: List[SubqueryAlias] = []
if len(kwargs) > 0:
from pyspark.sql.connect.sql_formatter import SQLStringFormatter
formatter = SQLStringFormatter(self)
sqlQuery = formatter.format(sqlQuery, **kwargs)
for df, name in formatter._temp_views:
_views.append(SubqueryAlias(df._plan, name))
cmd = SQL(sqlQuery, _args, _named_args, _views)
data, properties = self.client.execute_command(cmd.command(self._client))
if "sql_command_result" in properties:
return DataFrame(CachedRelation(properties["sql_command_result"]), self)
else:
return DataFrame(cmd, self)
sql.__doc__ = PySparkSession.sql.__doc__
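    # A hedged sketch of the supported argument styles (queries and names are examples):
    #
    #   spark.sql("SELECT * FROM range(10) WHERE id > :min", args={"min": 5})  # named
    #   spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[5])            # positional
    #   spark.sql("SELECT id FROM {df}", df=some_df)  # kwargs formatted via SQLStringFormatter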
def range(
self,
start: int,
end: Optional[int] = None,
step: int = 1,
numPartitions: Optional[int] = None,
) -> ParentDataFrame:
if end is None:
actual_end = start
start = 0
else:
actual_end = end
if numPartitions is not None:
numPartitions = int(numPartitions)
return DataFrame(
Range(
start=int(start), end=int(actual_end), step=int(step), num_partitions=numPartitions
),
self,
)
range.__doc__ = PySparkSession.range.__doc__
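    # Illustrative calls (values are arbitrary):
    #
    #   spark.range(5)         # ids 0, 1, 2, 3, 4 (single-argument form sets `end`)
    #   spark.range(2, 10, 2)  # ids 2, 4, 6, 8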
@property
def catalog(self) -> "Catalog":
from pyspark.sql.connect.catalog import Catalog
if not hasattr(self, "_catalog"):
self._catalog = Catalog(self)
return self._catalog
catalog.__doc__ = PySparkSession.catalog.__doc__
def __del__(self) -> None:
try:
# StreamingQueryManager has client states that needs to be cleaned up
if hasattr(self, "_sqm"):
self._sqm.close()
# Try its best to close.
self.client.close()
except Exception:
pass
def interruptAll(self) -> List[str]:
op_ids = self.client.interrupt_all()
assert op_ids is not None
return op_ids
interruptAll.__doc__ = PySparkSession.interruptAll.__doc__
def interruptTag(self, tag: str) -> List[str]:
op_ids = self.client.interrupt_tag(tag)
assert op_ids is not None
return op_ids
interruptTag.__doc__ = PySparkSession.interruptTag.__doc__
def interruptOperation(self, op_id: str) -> List[str]:
op_ids = self.client.interrupt_operation(op_id)
assert op_ids is not None
return op_ids
interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__
def addTag(self, tag: str) -> None:
self.client.add_tag(tag)
addTag.__doc__ = PySparkSession.addTag.__doc__
def removeTag(self, tag: str) -> None:
self.client.remove_tag(tag)
removeTag.__doc__ = PySparkSession.removeTag.__doc__
def getTags(self) -> Set[str]:
return self.client.get_tags()
getTags.__doc__ = PySparkSession.getTags.__doc__
def clearTags(self) -> None:
return self.client.clear_tags()
clearTags.__doc__ = PySparkSession.clearTags.__doc__
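    # A hedged sketch of how tags and interruption compose (the tag name is illustrative):
    #
    #   spark.addTag("nightly-etl")        # subsequent operations carry this tag
    #   ...                                # run queries
    #   spark.interruptTag("nightly-etl")  # interrupt only the tagged operations
    #   spark.clearTags()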
def stop(self) -> None:
        # Whereas stopping the regular PySpark session immediately terminates the Spark
        # Context itself, and therefore all Spark sessions sharing it, this only stops
        # this one session on the server.
        # Following the regular Spark session's behavior is debatable here: the Spark
        # Connect server is designed for multi-tenancy, so the remote client side cannot
        # just stop the server and thereby break remote clients used by other users.
with SparkSession._lock:
if not self.is_stopped and self.release_session_on_close:
self.client.release_session()
self.client.close()
if self is SparkSession._default_session:
SparkSession._default_session = None
if self is getattr(SparkSession._active_session, "session", None):
SparkSession._active_session.session = None
if "SPARK_LOCAL_REMOTE" in os.environ:
                # When local mode is in use, follow the regular Spark session's
                # behavior by terminating the Spark Connect server, so that local
                # mode can be stopped and the Spark Connect client restarted with a
                # different remote address.
if PySparkSession._activeSession is not None:
PySparkSession._activeSession.stop()
del os.environ["SPARK_LOCAL_REMOTE"]
del os.environ["SPARK_CONNECT_MODE_ENABLED"]
if "SPARK_REMOTE" in os.environ:
del os.environ["SPARK_REMOTE"]
stop.__doc__ = PySparkSession.stop.__doc__
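    # Illustrative consequence of the semantics above: stop() releases only this remote
    # session, and a new one can be created afterwards without affecting other tenants:
    #
    #   spark.stop()
    #   spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()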
@property
def is_stopped(self) -> bool:
"""
        Returns whether this session is stopped.
"""
return self.client.is_closed
@property
def conf(self) -> RuntimeConf:
return RuntimeConf(self.client)
conf.__doc__ = PySparkSession.conf.__doc__
@property
def streams(self) -> "StreamingQueryManager":
if hasattr(self, "_sqm"):
return self._sqm
self._sqm: StreamingQueryManager = StreamingQueryManager(self)
return self._sqm
streams.__doc__ = PySparkSession.streams.__doc__
def __getattr__(self, name: str) -> Any:
if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession", "sparkContext", "newSession"]:
raise PySparkAttributeError(
error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
)
return object.__getattribute__(self, name)
@property
def udf(self) -> "UDFRegistration":
from pyspark.sql.connect.udf import UDFRegistration
return UDFRegistration(self)
udf.__doc__ = PySparkSession.udf.__doc__
@property
def udtf(self) -> "UDTFRegistration":
from pyspark.sql.connect.udtf import UDTFRegistration
return UDTFRegistration(self)
udtf.__doc__ = PySparkSession.udtf.__doc__
@property
def dataSource(self) -> "DataSourceRegistration":
from pyspark.sql.connect.datasource import DataSourceRegistration
return DataSourceRegistration(self)
dataSource.__doc__ = PySparkSession.dataSource.__doc__
@property
def version(self) -> str:
result = self._client._analyze(method="spark_version").spark_version
assert result is not None
return result
version.__doc__ = PySparkSession.version.__doc__
@property
def client(self) -> "SparkConnectClient":
return self._client
client.__doc__ = PySparkSession.client.__doc__
def addArtifacts(
self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False
) -> None:
if sum([file, pyfile, archive]) > 1:
raise PySparkValueError(
error_class="INVALID_MULTIPLE_ARGUMENT_CONDITIONS",
message_parameters={
"arg_names": "'pyfile', 'archive' and/or 'file'",
"condition": "True together",
},
)
self._client.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
addArtifacts.__doc__ = PySparkSession.addArtifacts.__doc__
addArtifact = addArtifacts
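    # A hedged usage sketch (paths are illustrative); at most one of `pyfile`, `archive`
    # and `file` may be True per call:
    #
    #   spark.addArtifacts("/tmp/extra_module.py", pyfile=True)
    #   spark.addArtifacts("/tmp/resources.zip", archive=True)
    #   spark.addArtifacts("/tmp/lookup.csv", file=True)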
def _cache_local_relation(self, local_relation: LocalRelation) -> str:
"""
        Cache the local relation on the server side if it has not been cached yet.
"""
serialized = local_relation.serialize(self._client)
return self._client.cache_artifact(serialized)
def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
if urllib.parse.urlparse(dest_path).scheme:
raise PySparkValueError(
error_class="NO_SCHEMA_AND_DRIVER_DEFAULT_SCHEME",
message_parameters={"arg_name": "dest_path"},
)
self._client.copy_from_local_to_fs(local_path, dest_path)
copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__
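    # Illustrative call (paths are assumptions). `dest_path` must not carry a scheme, so
    # the driver's default filesystem scheme is used, as enforced above:
    #
    #   spark.copyFromLocalToFs("/tmp/report.csv", "/user/spark/report.csv")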
def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame":
"""
        An internal API to reference a runtime DataFrame on the server side.
        This is used in the ForeachBatch() runner, where the remote DataFrame refers to
        the output of a micro-batch.
"""
return DataFrame(CachedRemoteRelation(remote_id), self)
@staticmethod
def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
"""
        Starts the Spark Connect server given the master (thread-unsafe).

        At a high level, there are two cases. The first case is the development case, e.g.,
        you build Apache Spark locally and run ``SparkSession.builder.remote("local")``:

        1. This method automatically finds the jars for Spark Connect (because the jars for
          Spark Connect are not bundled in the regular Apache Spark release).

        2. Temporarily removes all Spark Connect state, for example, the ``SPARK_REMOTE``
          environment variable.

        3. Starts a JVM (without a Spark Context) first, and adds the Spark Connect server jars
          into the current class loader. Otherwise, a Spark Context with ``spark.plugins``
          cannot be initialized because the JVM is already running without the jars in
          the classpath before this Python process is executed for the driver side (in the
          case of a PySpark application submission).

        4. Starts a regular Spark session that automatically starts a Spark Connect server
          via the ``spark.plugins`` feature.

        The second case is when you use an Apache Spark release:

        1. Users must specify either the jars or the package, e.g., ``--packages
          org.apache.spark:spark-connect_2.12:3.4.0``. The jars or packages are specified
          in SparkSubmit automatically. This method does not do anything related to this.

        2. Temporarily removes all Spark Connect state, for example, the ``SPARK_REMOTE``
          environment variable. It does not do anything for PySpark application submission
          either, because the jars or packages were already specified before this Python
          process was executed for the driver side.

        3. Starts a regular Spark session that automatically starts a Spark Connect server
          with the JVM via the ``spark.plugins`` feature.
"""
from pyspark import SparkContext, SparkConf, __version__
session = PySparkSession._instantiatedSession
if session is None or session._sc._jsc is None:
# Configurations to be overwritten
overwrite_conf = opts
overwrite_conf["spark.master"] = master
overwrite_conf["spark.local.connect"] = "1"
# Configurations to be set if unset.
default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
if "SPARK_TESTING" in os.environ:
# For testing, we use 0 to use an ephemeral port to allow parallel testing.
# See also SPARK-42272.
overwrite_conf["spark.connect.grpc.binding.port"] = "0"
def create_conf(**kwargs: Any) -> SparkConf:
conf = SparkConf(**kwargs)
for k, v in overwrite_conf.items():
conf.set(k, v)
for k, v in default_conf.items():
if not conf.contains(k):
conf.set(k, v)
return conf
            # Check if we're using an unreleased version that is in development.
            # Also check SPARK_TESTING for RC versions.
is_dev_mode = (
"dev" in LooseVersion(__version__).version or "SPARK_TESTING" in os.environ
)
origin_remote = os.environ.get("SPARK_REMOTE", None)
try:
if origin_remote is not None:
# So SparkSubmit thinks no remote is set in order to
# start the regular PySpark session.
del os.environ["SPARK_REMOTE"]
SparkContext._ensure_initialized(conf=create_conf(loadDefaults=False))
if is_dev_mode:
                    # Use try/except to cover the case where this runs in production,
                    # because pyspark.testing does not exist in the canonical release.
try:
from pyspark.testing.utils import search_jar
                        # Note that, in production, the spark.jars.packages configuration
                        # should be set by users. Here we automatically search for the
                        # locally built jars.
connect_jar = search_jar(
"connector/connect/server", "spark-connect-assembly-", "spark-connect"
)
if connect_jar is None:
warnings.warn(
"Attempted to automatically find the Spark Connect jars because "
"'SPARK_TESTING' environment variable is set, or the current "
f"PySpark version is dev version ({__version__}). However, the jar"
" was not found. Manually locate the jars and specify them, e.g., "
"'spark.jars' configuration."
)
else:
pyutils = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
pyutils.addJarToCurrentClassLoader(connect_jar)
# Required for local-cluster testing as their executors need the jars
# to load the Spark plugin for Spark Connect.
if master.startswith("local-cluster"):
if "spark.jars" in overwrite_conf:
overwrite_conf[
"spark.jars"
] = f"{overwrite_conf['spark.jars']},{connect_jar}"
else:
overwrite_conf["spark.jars"] = connect_jar
except ImportError:
pass
                # The regular PySpark session is registered as an active session
                # so it will not be garbage-collected.
PySparkSession(
SparkContext.getOrCreate(create_conf(loadDefaults=True, _jvm=SparkContext._jvm))
)
                # Lastly, only keep runtime configurations, because other configurations
                # cannot be set in a regular Spark Connect session.
utl = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
runtime_conf_keys = [c._1() for c in utl.listRuntimeSQLConfigs()]
new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
opts.clear()
opts.update(new_opts)
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
else:
raise PySparkRuntimeError(
error_class="SESSION_OR_CONTEXT_EXISTS",
message_parameters={},
)
@property
def session_id(self) -> str:
return self._session_id
@property
def _profiler_collector(self) -> ProfilerCollector:
return self._client._profiler_collector
@property
def profile(self) -> Profile:
return Profile(self._client._profiler_collector)
profile.__doc__ = PySparkSession.profile.__doc__
def __reduce__(self) -> Tuple:
"""
        This method is called when the object is pickled. It returns a tuple of the object's
        constructor function, its arguments, and the object's local state.
        It is only meant to be used when the active Spark session that is pickled
        is the same active Spark session that is unpickled.
"""
def creator(old_session_id: str) -> "SparkSession":
# We cannot perform the checks for session matching here because accessing the
# session ID property causes the serialization of a WeakRef and in turn breaks
# the serialization.
return SparkSession._getActiveSessionIfMatches(old_session_id)
return creator, (self._session_id,)
SparkSession.__doc__ = PySparkSession.__doc__
def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.session
globs = pyspark.sql.connect.session.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.session tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)
# Uses PySpark session to test builder.
globs["SparkSession"] = PySparkSession
    # Spark Connect does not support setting the master.
pyspark.sql.connect.session.SparkSession.__doc__ = None
del pyspark.sql.connect.session.SparkSession.Builder.master.__doc__
(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.session,
globs=globs,
optionflags=doctest.ELLIPSIS
| doctest.NORMALIZE_WHITESPACE
| doctest.IGNORE_EXCEPTION_DETAIL,
)
globs["spark"].stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()