#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
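"""
Helpers for capturing ``stdout``, ``stderr`` and Python logging output in PySpark
workers and mirroring it as marker-prefixed JSON log lines on the worker's original
standard output.
"""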
from contextlib import contextmanager
import inspect
import io
import logging
import os
import sys
import time
from typing import BinaryIO, Callable, Generator, Iterable, Iterator, Optional, TextIO, Union
from types import FrameType, TracebackType
from pyspark.logger.logger import JSONFormatter
class DelegatingTextIOWrapper(TextIO):
"""A TextIO that delegates all operations to another TextIO object."""
def __init__(self, delegate: TextIO):
self._delegate = delegate
# Required TextIO properties
@property
def encoding(self) -> str:
return self._delegate.encoding
@property
def errors(self) -> Optional[str]:
return self._delegate.errors
@property
def newlines(self) -> Optional[Union[str, tuple[str, ...]]]:
return self._delegate.newlines
@property
def buffer(self) -> BinaryIO:
return self._delegate.buffer
@property
def mode(self) -> str:
return self._delegate.mode
@property
def name(self) -> str:
return self._delegate.name
@property
def line_buffering(self) -> int:
return self._delegate.line_buffering
@property
def closed(self) -> bool:
return self._delegate.closed
# Iterator protocol
def __iter__(self) -> Iterator[str]:
return iter(self._delegate)
def __next__(self) -> str:
return next(self._delegate)
# Context manager protocol
def __enter__(self) -> TextIO:
return self._delegate.__enter__()
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
return self._delegate.__exit__(exc_type, exc_val, exc_tb)
# Core I/O methods
def write(self, s: str) -> int:
return self._delegate.write(s)
def writelines(self, lines: Iterable[str]) -> None:
return self._delegate.writelines(lines)
def read(self, size: int = -1) -> str:
return self._delegate.read(size)
def readline(self, size: int = -1) -> str:
return self._delegate.readline(size)
def readlines(self, hint: int = -1) -> list[str]:
return self._delegate.readlines(hint)
# Stream control methods
def close(self) -> None:
return self._delegate.close()
def flush(self) -> None:
return self._delegate.flush()
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
return self._delegate.seek(offset, whence)
def tell(self) -> int:
return self._delegate.tell()
def truncate(self, size: Optional[int] = None) -> int:
return self._delegate.truncate(size)
# Stream capability methods
def fileno(self) -> int:
return self._delegate.fileno()
def isatty(self) -> bool:
return self._delegate.isatty()
def readable(self) -> bool:
return self._delegate.readable()
def seekable(self) -> bool:
return self._delegate.seekable()
def writable(self) -> bool:
return self._delegate.writable()
class JSONFormatterWithMarker(JSONFormatter):
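    """
    A :class:`JSONFormatter` that prefixes each formatted record with
    ``<marker>:<worker_id>:``, merges caller context obtained from
    ``context_provider`` into the record, and renders timestamps with
    microsecond precision followed by the local UTC offset.
    """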
default_microsec_format = "%s.%06d"
def __init__(self, marker: str, worker_id: str, context_provider: Callable[[], dict[str, str]]):
super().__init__(ensure_ascii=True)
self._marker = marker
self._worker_id = worker_id
self._context_provider = context_provider
def format(self, record: logging.LogRecord) -> str:
context = self._context_provider()
if context:
context.update(record.__dict__.get("context", {}))
record.__dict__["context"] = context
return f"{self._marker}:{self._worker_id}:{super().format(record)}"
def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str:
ct = self.converter(record.created)
if datefmt:
s = time.strftime(datefmt, ct)
else:
s = time.strftime(self.default_time_format, ct)
if self.default_microsec_format:
s = self.default_microsec_format % (
s,
int((record.created - int(record.created)) * 1000000),
)
elif self.default_msec_format:
s = self.default_msec_format % (s, record.msecs)
s = f"{s}{time.strftime('%z', ct)}"
return s
class JsonOutput(DelegatingTextIOWrapper):
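    """
    A replacement for ``sys.stdout``/``sys.stderr`` that forwards raw writes to
    ``delegate`` unchanged and, in addition, emits every non-blank write to
    ``json_out`` as a marker-prefixed JSON log record.
    """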
def __init__(
self,
delegate: TextIO,
json_out: TextIO,
logger_name: str,
log_level: int,
marker: str,
worker_id: str,
context_provider: Callable[[], dict[str, str]],
):
super().__init__(delegate)
self._json_out = json_out
self._logger_name = logger_name
self._log_level = log_level
self._formatter = JSONFormatterWithMarker(marker, worker_id, context_provider)
def write(self, s: str) -> int:
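        # Mirror non-blank output as a JSON log record on ``json_out`` while still
        # forwarding the raw text to the delegate stream.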
if s.strip():
log_record = logging.LogRecord(
name=self._logger_name,
level=self._log_level,
pathname=None, # type: ignore[arg-type]
lineno=None, # type: ignore[arg-type]
msg=s.strip(),
args=None,
exc_info=None,
func=None,
sinfo=None,
)
self._json_out.write(f"{self._formatter.format(log_record)}\n")
self._json_out.flush()
return self._delegate.write(s)
def writelines(self, lines: Iterable[str]) -> None:
# Process each line through our JSON logging logic
for line in lines:
self.write(line)
    def close(self) -> None:
        # Intentionally a no-op so that closing the replaced sys.stdout/sys.stderr
        # does not close the underlying delegate or JSON output streams.
        pass
def context_provider() -> dict[str, str]:
"""
Provides context information for logging, including caller function name.
Finds the function name from the bottom of the stack, ignoring Python builtin
    libraries and PySpark modules; PySpark test packages are not ignored.
Returns:
dict[str, str]: A dictionary containing context information including:
- func_name: Name of the function that initiated the logging
- class_name: Name of the class that initiated the logging if available
"""
def is_pyspark_module(module_name: str) -> bool:
return module_name.startswith("pyspark.") and ".tests." not in module_name
bottom: Optional[FrameType] = None
# Get caller function information using inspect
try:
frame = inspect.currentframe()
is_in_pyspark_module = False
if frame:
while frame.f_back:
f_back = frame.f_back
module_name = f_back.f_globals.get("__name__", "")
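                # Each time the walk towards the bottom of the stack re-enters PySpark
                # code, remember the current frame (the one invoked from PySpark); the
                # last one recorded, nearest the bottom, is reported as the call site.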
if is_pyspark_module(module_name):
if not is_in_pyspark_module:
bottom = frame
is_in_pyspark_module = True
else:
is_in_pyspark_module = False
frame = f_back
except Exception:
# If anything goes wrong with introspection, don't fail the logging
# Just continue without caller information
pass
    context: dict[str, str] = {}
if bottom:
context["func_name"] = bottom.f_code.co_name
if "self" in bottom.f_locals:
context["class_name"] = bottom.f_locals["self"].__class__.__name__
elif "cls" in bottom.f_locals:
context["class_name"] = bottom.f_locals["cls"].__name__
return context
@contextmanager
def capture_outputs(
context_provider: Callable[[], dict[str, str]] = context_provider
) -> Generator[None, None, None]:
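    """
    Capture ``stdout``, ``stderr`` and root-logger output of a Python worker and
    mirror them as marker-prefixed JSON log lines on the worker's original standard
    output. This is a no-op unless ``PYSPARK_SPARK_SESSION_UUID`` is set in the
    environment.

    Each mirrored line has the form ``PYTHON_WORKER_LOGGING:<pid>:<json>``; an empty
    ``PYTHON_WORKER_LOGGING:<pid>:`` line is written when the block exits to signal
    the end of the captured output. Raw text written to the replaced streams is
    passed through to the original ``stderr``. A minimal usage sketch::

        with capture_outputs():
            print("hello")  # mirrored as a JSON log record on the original stdout
    """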
if "PYSPARK_SPARK_SESSION_UUID" in os.environ:
marker: str = "PYTHON_WORKER_LOGGING"
worker_id: str = str(os.getpid())
json_out = original_stdout = sys.stdout
delegate = original_stderr = sys.stderr
handler = logging.StreamHandler(json_out)
handler.setFormatter(JSONFormatterWithMarker(marker, worker_id, context_provider))
logger = logging.getLogger()
try:
sys.stdout = JsonOutput(
delegate, json_out, "stdout", logging.INFO, marker, worker_id, context_provider
)
sys.stderr = JsonOutput(
delegate, json_out, "stderr", logging.ERROR, marker, worker_id, context_provider
)
logger.addHandler(handler)
try:
yield
finally:
# Send an empty line to indicate the end of the outputs.
json_out.write(f"{marker}:{worker_id}:\n")
json_out.flush()
finally:
sys.stdout = original_stdout
sys.stderr = original_stderr
logger.removeHandler(handler)
handler.close()
else:
yield