#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.connect.utils import check_dependencies
from pyspark.sql.connect.client.logging import logger
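# Fail fast if the optional Spark Connect dependencies (e.g. grpcio, protobuf) are missing,
# before the grpc/proto imports below are attempted.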
check_dependencies(__name__)
import hashlib
import importlib
import io
import sys
import os
import zlib
from itertools import chain
from typing import List, Iterable, BinaryIO, Iterator, Optional, Tuple
import abc
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import url2pathname
from functools import cached_property
import grpc
import pyspark.sql.connect.proto as proto
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
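# Prefixes used to namespace artifact paths so the server can distinguish jars, Python files,
# archives, plain files, payloads forwarded to a file system, and cached blobs.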
JAR_PREFIX: str = "jars"
PYFILE_PREFIX: str = "pyfiles"
ARCHIVE_PREFIX: str = "archives"
FILE_PREFIX: str = "files"
FORWARD_TO_FS_PREFIX: str = "forward_to_fs"
CACHE_PREFIX: str = "cache"
class LocalData(metaclass=abc.ABCMeta):
"""
Payload stored on this machine.
"""
@abc.abstractmethod
def stream(self) -> BinaryIO:
pass
@cached_property
@abc.abstractmethod
def size(self) -> int:
pass
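# Two concrete payload types are provided below: LocalFile for data on disk and InMemory for
# byte blobs held in memory.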
class LocalFile(LocalData):
"""
Payload stored in a local file.
"""
def __init__(self, path: str):
self.path = path
        # Check that the file can be read,
        # so that incorrect references are discovered at Artifact creation time
        # rather than at the point of consumption.
with self.stream():
pass
@cached_property
def size(self) -> int:
return os.path.getsize(self.path)
def stream(self) -> BinaryIO:
return open(self.path, "rb")
class InMemory(LocalData):
"""
Payload stored in memory.
"""
def __init__(self, blob: bytes):
self.blob = blob
@cached_property
def size(self) -> int:
return len(self.blob)
def stream(self) -> BinaryIO:
return io.BytesIO(self.blob)
class Artifact:
def __init__(self, path: str, storage: LocalData):
assert not Path(path).is_absolute(), f"Bad path: {path}"
self.path = path
self.storage = storage
@cached_property
def size(self) -> int:
if isinstance(self.storage, LocalData):
return self.storage.size
else:
raise PySparkRuntimeError(
errorClass="UNSUPPORTED_OPERATION",
messageParameters={"operation": f"{self.storage} storage"},
)
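# Factory helpers that build Artifact instances with the server-side prefix baked into the
# relative path, e.g. new_jar_artifact("udfs.jar", LocalFile(path)) -> "jars/udfs.jar"
# (the jar name here is just an illustrative placeholder).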
def new_jar_artifact(file_name: str, storage: LocalData) -> Artifact:
return _new_artifact(JAR_PREFIX, ".jar", file_name, storage)
def new_pyfile_artifact(file_name: str, storage: LocalData) -> Artifact:
assert any(file_name.endswith(s) for s in (".py", ".zip", ".egg", ".jar"))
return _new_artifact(PYFILE_PREFIX, "", file_name, storage)
def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)
def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
return _new_artifact(FILE_PREFIX, "", file_name, storage)
def new_cache_artifact(id: str, storage: LocalData) -> Artifact:
return _new_artifact(CACHE_PREFIX, "", id, storage)
def _new_artifact(
prefix: str, required_suffix: str, file_name: str, storage: LocalData
) -> Artifact:
assert not Path(file_name).is_absolute()
assert file_name.endswith(required_suffix)
return Artifact(os.path.join(prefix, file_name), storage)
class ArtifactManager:
"""
The Artifact Manager is responsible for handling and transferring artifacts from the local
client to the server (local/remote).
Parameters
----------
user_id : str, optional
User ID.
session_id: str
        A unique identifier of the session to which this artifact manager belongs.
channel: grpc.Channel
GRPC Channel instance.
"""
# Using the midpoint recommendation of 32KiB for chunk size as specified in
# https://github.com/grpc/grpc.github.io/issues/371.
CHUNK_SIZE: int = 32 * 1024
def __init__(
self,
user_id: Optional[str],
session_id: str,
channel: grpc.Channel,
metadata: Iterable[Tuple[str, str]],
):
self._user_context = proto.UserContext()
if user_id is not None:
self._user_context.user_id = user_id
self._stub = grpc_lib.SparkConnectServiceStub(channel)
self._session_id = session_id
self._metadata = metadata
def _parse_artifacts(
self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
) -> List[Artifact]:
        # Currently only local files are supported (jars, Python files, archives and plain files).
parsed = urlparse(path_or_uri)
# Check if it is a file from the scheme
if parsed.scheme == "":
            # Similar to Utils.resolveURI.
fragment = parsed.fragment
parsed = urlparse(Path(url2pathname(parsed.path)).absolute().as_uri())
parsed = parsed._replace(fragment=fragment)
if parsed.scheme == "file":
local_path = url2pathname(parsed.path)
name = Path(local_path).name
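            # For Python dependencies the local import machinery is updated as well: zip/egg/jar
            # files are inserted into sys.path so client-side code can import them, and the
            # importlib caches are invalidated to pick up the new entries.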
if pyfile and name.endswith(".py"):
artifact = new_pyfile_artifact(name, LocalFile(local_path))
importlib.invalidate_caches()
elif pyfile and (
name.endswith(".zip") or name.endswith(".egg") or name.endswith(".jar")
):
sys.path.insert(1, local_path)
artifact = new_pyfile_artifact(name, LocalFile(local_path))
importlib.invalidate_caches()
elif archive and (
name.endswith(".zip")
or name.endswith(".jar")
or name.endswith(".tar.gz")
or name.endswith(".tgz")
or name.endswith(".tar")
):
assert any(name.endswith(s) for s in (".zip", ".jar", ".tar.gz", ".tgz", ".tar"))
if parsed.fragment != "":
                    # Minimal workaround for fragment handling in the URI.
                    # Limitation: a hash (#) in the file name itself is not supported.
if "#" in local_path:
raise PySparkValueError(
errorClass="VALUE_ALLOWED",
messageParameters={
"arg_name": "artifact path",
"disallowed_value": "#",
},
)
name = f"{name}#{parsed.fragment}"
artifact = new_archive_artifact(name, LocalFile(local_path))
elif file:
artifact = new_file_artifact(name, LocalFile(local_path))
elif name.endswith(".jar"):
artifact = new_jar_artifact(name, LocalFile(local_path))
else:
raise PySparkRuntimeError(
errorClass="UNSUPPORTED_OPERATION",
messageParameters={"operation": f"{local_path} file format"},
)
return [artifact]
raise PySparkRuntimeError(
errorClass="UNSUPPORTED_OPERATION",
messageParameters={"operation": f"{parsed.scheme} scheme"},
)
def _parse_forward_to_fs_artifacts(self, local_path: str, dest_path: str) -> List[Artifact]:
abs_path: Path = Path(local_path).absolute()
# TODO: Support directory path.
assert abs_path.is_file(), "local path must be a file path."
storage = LocalFile(str(abs_path))
assert Path(dest_path).is_absolute(), "destination FS path must be an absolute path."
        # Since `dest_path` is an absolute path, `os.path.join` would discard the prefix,
        # so plain string concatenation is used to prepend FORWARD_TO_FS_PREFIX.
artifact_path = FORWARD_TO_FS_PREFIX + dest_path
return [Artifact(artifact_path, storage)]
def _create_requests(
self, *path: str, pyfile: bool, archive: bool, file: bool
) -> Iterator[proto.AddArtifactsRequest]:
"""Separated for the testing purpose."""
# It's crucial that this function is not generator, but only returns generator.
# This way we are doing artifact parsing within the original caller thread
# And not during grpc consuming iterator, allowing for much better error reporting.
artifacts: Iterator[Artifact] = chain(
*(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
)
def generator() -> Iterator[proto.AddArtifactsRequest]:
try:
yield from self._add_artifacts(artifacts)
except Exception as e:
logger.error(f"Failed to submit addArtifacts request: {e}")
raise
return generator()
def _retrieve_responses(
self, requests: Iterator[proto.AddArtifactsRequest]
) -> proto.AddArtifactsResponse:
"""Separated for the testing purpose."""
return self._stub.AddArtifacts(requests, metadata=self._metadata)
def _request_add_artifacts(self, requests: Iterator[proto.AddArtifactsRequest]) -> None:
response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
for summary in response.artifacts:
summaries.append(summary)
# TODO(SPARK-42658): Handle responses containing CRC failures.
def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
"""
        Add one or more artifacts to the session.
        Currently only local files are supported.
"""
requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
*path, pyfile=pyfile, archive=archive, file=file
)
self._request_add_artifacts(requests)
def _add_forward_to_fs_artifacts(self, local_path: str, dest_path: str) -> None:
requests: Iterator[proto.AddArtifactsRequest] = self._add_artifacts(
self._parse_forward_to_fs_artifacts(local_path, dest_path)
)
self._request_add_artifacts(requests)
def _add_artifacts(self, artifacts: Iterable[Artifact]) -> Iterator[proto.AddArtifactsRequest]:
"""
Add a number of artifacts to the session.
"""
current_batch: List[Artifact] = []
current_batch_size = 0
def add_to_batch(dep: Artifact, size: int) -> None:
nonlocal current_batch
nonlocal current_batch_size
current_batch.append(dep)
current_batch_size += size
def write_batch() -> Iterator[proto.AddArtifactsRequest]:
nonlocal current_batch
nonlocal current_batch_size
yield from self._add_batched_artifacts(current_batch)
current_batch = []
current_batch_size = 0
for artifact in artifacts:
data = artifact.storage
size = data.size
if size > ArtifactManager.CHUNK_SIZE:
# Payload can either be a batch OR a single chunked artifact.
# Write batch if non-empty before chunking current artifact.
if len(current_batch) > 0:
yield from write_batch()
yield from self._add_chunked_artifact(artifact)
else:
if current_batch_size + size > ArtifactManager.CHUNK_SIZE:
yield from write_batch()
add_to_batch(artifact, size)
if len(current_batch) > 0:
yield from write_batch()
def _add_batched_artifacts(
self, artifacts: Iterable[Artifact]
) -> Iterator[proto.AddArtifactsRequest]:
"""
Add a batch of artifacts to the stream. All the artifacts in this call are packaged into a
single :class:`proto.AddArtifactsRequest`.
"""
artifact_chunks = []
for artifact in artifacts:
with artifact.storage.stream() as stream:
binary = stream.read()
crc32 = zlib.crc32(binary)
data = proto.AddArtifactsRequest.ArtifactChunk(data=binary, crc=crc32)
artifact_chunks.append(
proto.AddArtifactsRequest.SingleChunkArtifact(name=artifact.path, data=data)
)
# Write the request once
yield proto.AddArtifactsRequest(
session_id=self._session_id,
user_context=self._user_context,
batch=proto.AddArtifactsRequest.Batch(artifacts=artifact_chunks),
)
def _add_chunked_artifact(self, artifact: Artifact) -> Iterator[proto.AddArtifactsRequest]:
"""
        Add an artifact in chunks to the stream. The artifact's data is spread out over multiple
        :class:`proto.AddArtifactsRequest` messages.
"""
initial_batch = True
# Integer division that rounds up to the nearest whole number.
get_num_chunks = int(
(artifact.size + (ArtifactManager.CHUNK_SIZE - 1)) / ArtifactManager.CHUNK_SIZE
)
# Consume stream in chunks until there is no data left to read.
with artifact.storage.stream() as stream:
for chunk in iter(lambda: stream.read(ArtifactManager.CHUNK_SIZE), b""):
if initial_batch:
# First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
yield proto.AddArtifactsRequest(
session_id=self._session_id,
user_context=self._user_context,
begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
name=artifact.path,
total_bytes=artifact.size,
num_chunks=get_num_chunks,
initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
data=chunk, crc=zlib.crc32(chunk)
),
),
)
initial_batch = False
else:
                    # Subsequent RPCs contain the `ArtifactChunk` payload (`chunk`).
yield proto.AddArtifactsRequest(
session_id=self._session_id,
user_context=self._user_context,
chunk=proto.AddArtifactsRequest.ArtifactChunk(
data=chunk, crc=zlib.crc32(chunk)
),
)
def is_cached_artifact(self, hash: str) -> bool:
"""
        Ask the server whether an artifact with `hash` has already been cached on the server side.
"""
artifactName = CACHE_PREFIX + "/" + hash
request = proto.ArtifactStatusesRequest(
user_context=self._user_context, session_id=self._session_id, names=[artifactName]
)
resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(
request, metadata=self._metadata
)
status = resp.statuses.get(artifactName)
return status.exists if status is not None else False
def cache_artifact(self, blob: bytes) -> str:
"""
        Cache the given blob at the session.
"""
hash = hashlib.sha256(blob).hexdigest()
if not self.is_cached_artifact(hash):
requests = self._add_artifacts([new_cache_artifact(hash, InMemory(blob))])
response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
for summary in response.artifacts:
summaries.append(summary)
# TODO(SPARK-42658): Handle responses containing CRC failures.
return hash