blob: ad1fc810f3ef4937841e7fce3947ab5f2ce0f483 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# type: ignore
from contextlib import closing
import time
import socket
import socketserver
from struct import pack, unpack
import sys
import threading
import traceback
from typing import Generator
import warnings
# Wire format used by the code in this file: each message is sent as a 4-byte big-endian
# signed length (``pack(">i", ...)``) followed by that many bytes of utf-8 encoded text.
# NOTE(review): this comment previously described using b'\x00' as a separator instead of
# b'\n' because the bytes are encoded in utf-8; no separator byte is used by the visible code.

# Seconds between polls: passed to ``serve_forever`` and used as the handler's idle sleep.
_SERVER_POLL_INTERVAL = 0.1
# Messages longer than this many characters are cut and suffixed with "...(truncated)".
_TRUNCATE_MSG_LEN = 4000
# Held while a received message is written to stderr so concurrent handlers do not interleave.
_log_print_lock = threading.Lock() # pylint: disable=invalid-name
def _get_log_print_lock() -> threading.Lock:
    """Return the module-level lock that serializes printing of received log messages."""
    return _log_print_lock
class WriteLogToStdout(socketserver.StreamRequestHandler):
    """
    Request handler that reads length-prefixed messages from one client connection and
    writes each of them, followed by a newline, to ``sys.stderr``.

    Each message on the wire is a 4-byte big-endian signed length followed by that many
    bytes of utf-8 encoded text.
    """

    def _read_exact(self, size: int) -> "bytes | None":
        """
        Read exactly ``size`` bytes from the (non-blocking) connection.

        :param size: Number of bytes to read; ``0`` returns ``b""`` immediately.
        :return: The bytes read, or ``None`` if the server became inactive or the peer
            closed the connection before ``size`` bytes arrived.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            if not self.server.is_active:
                return None
            chunk = self.rfile.read(remaining)
            if chunk is None:
                # Non-blocking read with nothing available yet: wait and poll again.
                time.sleep(_SERVER_POLL_INTERVAL)
                continue
            if not chunk:
                # Empty bytes means end of stream: the peer closed the connection.
                return None
            # A non-blocking read may return fewer bytes than requested; keep accumulating
            # so that a frame is never parsed from a partial read.
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _read_bline(self) -> Generator[bytes, None, None]:
        """
        Yield the payload of each complete message received on this connection.

        The generator ends when the server is marked inactive, when the peer closes the
        connection, or when a negative (corrupt) length prefix is received.
        """
        while self.server.is_active:
            packed_number_bytes = self._read_exact(4)
            if packed_number_bytes is None:
                return
            number_bytes = unpack(">i", packed_number_bytes)[0]
            if number_bytes < 0:
                # The length is a signed int; a negative value cannot be a valid frame.
                return
            message = self._read_exact(number_bytes)
            if message is None:
                return
            yield message

    def handle(self) -> None:
        """Print every message received on this connection to ``sys.stderr``."""
        self.request.setblocking(False)  # non-blocking mode
        for bline in self._read_bline():
            # Replace undecodable bytes instead of raising, so one bad payload does not
            # terminate the handler and drop the messages that follow it.
            text = bline.decode("utf-8", errors="replace")
            with _get_log_print_lock():
                sys.stderr.write(text + "\n")
                sys.stderr.flush()
# What is run on the local driver
class LogStreamingServer:
    """
    Runs a threaded TCP server in a background daemon thread; each connection is handled
    by :class:`WriteLogToStdout`.

    After :meth:`start` returns, ``port`` holds the listening port and ``server`` /
    ``serve_thread`` hold the running server and its thread.
    """

    def __init__(self) -> None:
        self.server = None
        self.serve_thread = None
        self.port = None

    @staticmethod
    def _get_free_port(spark_host_address: str = "") -> int:
        """
        Ask the OS for a currently unused TCP port on ``spark_host_address``.

        The probing socket is closed before returning, so the port is only known to have
        been free at the time of the call.
        """
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as tcp:
            tcp.bind((spark_host_address, 0))
            _, port = tcp.getsockname()
        return port

    def start(self, spark_host_address: str = "") -> None:
        """
        Bind the server and start serving in a daemon thread.

        The server is created and bound here, on the calling thread, so that ``server`` and
        ``port`` are valid as soon as this method returns and bind errors reach the caller
        instead of being lost inside the background thread.

        :param spark_host_address: Address used to probe for a free port. The server itself
            listens on ``0.0.0.0``.
        :raises RuntimeError: If the server has already been started.
        :raises OSError: If the probed port can no longer be bound.
        """
        if self.server is not None:
            raise RuntimeError("Cannot start the server twice.")

        port = LogStreamingServer._get_free_port(spark_host_address)
        server = socketserver.ThreadingTCPServer(("0.0.0.0", port), WriteLogToStdout)
        # Handlers poll this flag to know when to stop reading.
        server.is_active = True

        def serve_task() -> None:
            # Leaving the ``with`` block closes the listening socket once serving stops.
            with server:
                server.serve_forever(poll_interval=_SERVER_POLL_INTERVAL)

        self.port = port
        self.server = server
        self.serve_thread = threading.Thread(target=serve_task)
        self.serve_thread.daemon = True
        self.serve_thread.start()

    def shutdown(self) -> None:
        """Stop the server and wait for its thread to finish; no-op if not started."""
        if self.server is None:
            return
        # Sleep to ensure all log has been received and printed.
        time.sleep(_SERVER_POLL_INTERVAL * 2)
        # Before close we need to flush so that all buffered output has been printed.
        # Handlers write to stderr, so flush it as well as stdout.
        sys.stdout.flush()
        sys.stderr.flush()
        self.server.is_active = False
        self.server.shutdown()
        self.serve_thread.join()
        self.server = None
        self.serve_thread = None
class LogStreamingClientBase:
    """Base log streaming client whose ``send`` and ``close`` do nothing."""

    @staticmethod
    def _maybe_truncate_msg(message: str) -> str:
        """
        Return ``message`` unchanged if it has at most ``_TRUNCATE_MSG_LEN`` characters,
        otherwise its first ``_TRUNCATE_MSG_LEN`` characters followed by a truncation marker.
        """
        if len(message) <= _TRUNCATE_MSG_LEN:
            return message
        head = message[:_TRUNCATE_MSG_LEN]
        return head + "...(truncated)"

    def send(self, message: str) -> None:
        """Accept a message and discard it; subclasses override this to deliver it."""
        return None

    def close(self) -> None:
        """Release resources; nothing to release in the base class."""
        return None
class LogStreamingClient(LogStreamingClientBase):
    """
    A client that streams log messages to :class:`LogStreamingServer`.
    In case of failures, the client will skip messages instead of raising an error.
    """

    _log_callback_client = None
    _server_address = None
    _singleton_lock = threading.Lock()

    @staticmethod
    def _init(address: str, port: int) -> None:
        """Record the ``(address, port)`` of the log streaming server on the class."""
        LogStreamingClient._server_address = (address, port)

    @staticmethod
    def _destroy() -> None:
        """Clear the recorded server address and close the class-level client, if set."""
        LogStreamingClient._server_address = None
        if LogStreamingClient._log_callback_client is not None:
            LogStreamingClient._log_callback_client.close()

    def __init__(self, address: str, port: int, timeout: int = 10):
        """
        Creates a client for the logging server. No connection is opened here; the client
        connects on the first call to :meth:`send`.

        This client is best effort: if connecting or sending a message fails, a warning is
        emitted and the message is skipped instead of raising. After a failed send the
        client is marked as failed and skips further messages; after a failed connection
        attempt the next :meth:`send` tries to connect again.

        :param address: Address where the service is running.
        :param port: Port where the service is listening for new connections. ``-1`` means
            the server is not available.
        :param timeout: Socket timeout in seconds, applied to connecting and sending.
        """
        self.address = address
        self.port = port
        self.timeout = timeout
        self.sock = None
        self.failed = True
        self._lock = threading.RLock()

    def _fail(self, error_msg: str) -> None:
        """
        Mark the client as failed and emit ``error_msg`` as a warning.

        The current traceback is appended only when an exception is being handled, so a
        failure reported outside an ``except`` block does not print ``NoneType: None``.
        """
        self.failed = True
        if sys.exc_info()[0] is None:
            warnings.warn(f"{error_msg}\n")
        else:
            warnings.warn(f"{error_msg}: {traceback.format_exc()}\n")

    def _connect(self) -> None:
        """
        Open the TCP connection to the server.

        On success ``self.sock`` is set and ``self.failed`` is cleared. On failure the
        client is marked as failed, ``self.sock`` stays ``None`` and the partially created
        socket is closed so its file descriptor is not leaked.
        """
        if self.port == -1:
            self._fail("Log streaming server is not available.")
            return
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(self.timeout)
            sock.connect((self.address, self.port))
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        except OSError:
            # Report first, while the exception (and its traceback) is still active.
            self._fail("Error connecting log streaming server")
            if sock is not None:
                sock.close()
            return
        self.sock = sock
        self.failed = False

    def send(self, message: str) -> None:
        """
        Sends a message.

        The message is truncated if too long, encoded as utf-8 and sent prefixed with its
        byte length as a 4-byte big-endian signed integer. Errors are reported as warnings
        and never raised.
        """
        with self._lock:
            if self.sock is None:
                self._connect()
            if not self.failed:
                try:
                    message = LogStreamingClientBase._maybe_truncate_msg(message)
                    # TODO:
                    #  1) addressing issue: idle TCP connection might get disconnected by
                    #     cloud provider
                    #  2) sendall may block when server is busy handling data.
                    binary_message = message.encode("utf-8")
                    packed_number_bytes = pack(">i", len(binary_message))
                    self.sock.sendall(packed_number_bytes + binary_message)
                except Exception:  # pylint: disable=broad-except
                    # Deliberately broad: log streaming is best effort and must not raise.
                    self._fail("Error sending logs to driver, stopping log streaming")

    def close(self) -> None:
        """
        Closes the connection.
        """
        if self.sock:
            self.sock.close()
            self.sock = None