#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Util functions for workers.
"""
import importlib
from inspect import currentframe, getframeinfo
import os
import sys
from typing import Any, IO
import warnings
# 'resource' is a Unix-specific module.
has_resource_module = True
try:
import resource
except ImportError:
has_resource_module = False

from pyspark.accumulators import _accumulatorRegistry
from pyspark.util import is_remote_only
from pyspark.errors import PySparkRuntimeError
from pyspark.util import local_connect_and_auth
from pyspark.serializers import (
read_bool,
read_int,
read_long,
write_int,
FramedSerializer,
UTF8Deserializer,
CPickleSerializer,
)
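
# Module-level serializer instances shared by the helpers below: pickleSer writes the
# pickled accumulator updates, while utf8_deserializer reads the length-prefixed UTF-8
# strings (paths, file names, the driver's Python version) sent from the JVM.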
pickleSer = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()


def add_path(path: str) -> None:
    # the worker process can be reused, so do not add the same path multiple times
    if path not in sys.path:
        # insert at position 1 so added files take precedence over system packages
        sys.path.insert(1, path)


def read_command(serializer: FramedSerializer, file: IO) -> Any:
if not is_remote_only():
from pyspark.core.broadcast import Broadcast
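    # The command is a length-prefixed pickled payload; large commands are typically
    # shipped via a Broadcast, in which case the real command is unpickled from the
    # broadcast value below.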
command = serializer._read_with_length(file)
if not is_remote_only() and isinstance(command, Broadcast):
command = serializer.loads(command.value)
return command


def check_python_version(infile: IO) -> None:
    """
    Check that the Python version of this worker matches the version the driver
    used to serialize the command.
"""
version = utf8_deserializer.loads(infile)
worker_version = "%d.%d" % sys.version_info[:2]
if version != worker_version:
raise PySparkRuntimeError(
errorClass="PYTHON_VERSION_MISMATCH",
messageParameters={
"worker_version": worker_version,
"driver_version": str(version),
},
)
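
# Illustrative failure mode (the message text shown here is approximate): a driver that
# serialized the command under Python 3.11 talking to a worker running Python 3.10
# raises a PySparkRuntimeError with error class PYTHON_VERSION_MISMATCH, reporting
# worker_version="3.10" and driver_version="3.11".
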
def setup_memory_limits(memory_limit_mb: int) -> None:
"""
    Sets up the memory limits.

    If memory_limit_mb > 0 and the `resource` module is available, sets the memory
    limit to memory_limit_mb MiB. Windows does not support resource limiting, and
    the limit is not actually enforced on macOS.
"""
if memory_limit_mb > 0 and has_resource_module:
total_memory = resource.RLIMIT_AS
try:
(soft_limit, hard_limit) = resource.getrlimit(total_memory)
msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
print(msg, file=sys.stderr)
# convert to bytes
new_limit = memory_limit_mb * 1024 * 1024
if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
print(msg, file=sys.stderr)
resource.setrlimit(total_memory, (new_limit, new_limit))
except (resource.error, OSError, ValueError) as e:
# not all systems support resource limits, so warn instead of failing
current = currentframe()
lineno = getframeinfo(current).lineno + 1 if current is not None else 0
if "__file__" in globals():
print(
warnings.formatwarning(
"Failed to set memory limit: {0}".format(e),
ResourceWarning,
__file__,
lineno,
),
file=sys.stderr,
)
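
# Rough arithmetic for the limit applied above: memory_limit_mb is treated as MiB, so
# a value of 512 becomes 512 * 1024 * 1024 = 536870912 bytes, set as both the soft and
# hard RLIMIT_AS. The limit is only ever lowered, never raised above the current soft limit.
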
def setup_spark_files(infile: IO) -> None:
"""
Set up Spark files, archives, and pyfiles.
"""
# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
if not is_remote_only():
from pyspark.core.files import SparkFiles
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
add_path(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
filename = utf8_deserializer.loads(infile)
add_path(os.path.join(spark_files_dir, filename))
importlib.invalidate_caches()
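
# Framing sketch of what setup_spark_files reads from the stream:
#
#   [UTF-8 string] spark_files_dir (the worker's working directory for fetched files)
#   [int]          number of Python includes
#   [UTF-8 string] include filename, repeated that many times
#
# Each include is joined with spark_files_dir and inserted near the front of sys.path.
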
def setup_broadcasts(infile: IO) -> None:
"""
    Set up broadcast variables.
"""
if not is_remote_only():
from pyspark.core.broadcast import Broadcast, _broadcastRegistry
# fetch names and values of broadcast variables
needs_broadcast_decryption_server = read_bool(infile)
num_broadcast_variables = read_int(infile)
if needs_broadcast_decryption_server:
        # read the decrypted data from a server in the JVM
conn_info = read_int(infile)
auth_secret = None
if conn_info == -1:
conn_info = utf8_deserializer.loads(infile)
else:
auth_secret = utf8_deserializer.loads(infile)
(broadcast_sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
if needs_broadcast_decryption_server:
read_bid = read_long(broadcast_sock_file)
assert read_bid == bid
_broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
else:
path = utf8_deserializer.loads(infile)
_broadcastRegistry[bid] = Broadcast(path=path)
else:
bid = -bid - 1
_broadcastRegistry.pop(bid)
if needs_broadcast_decryption_server:
broadcast_sock_file.write(b"1")
broadcast_sock_file.close()
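
# Sketch of the broadcast section of the stream, as consumed above: a bool saying
# whether values come from the JVM's decryption server, the number of broadcast
# variables, the server's connection info and auth secret when decryption is needed,
# then one signed long per variable: a non-negative id registers a Broadcast (read
# from the decryption socket, or from a file path sent on the stream), while a
# negative id removes broadcast (-id - 1) from _broadcastRegistry.
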
def send_accumulator_updates(outfile: IO) -> None:
"""
    Send the accumulator updates back to the JVM.
"""
write_int(len(_accumulatorRegistry), outfile)
for aid, accum in _accumulatorRegistry.items():
pickleSer._write_with_length((aid, accum._value), outfile)
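
# Reply framing produced above: the number of registered accumulators as an int,
# followed by one length-prefixed pickled (accumulator_id, value) tuple per accumulator.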