#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Worker status api handler for reporting SDK harness debug info."""
import gc
import logging
import queue
import sys
import threading
import time
import traceback
from collections import defaultdict
import grpc
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils.sentinel import Sentinel
# guppy is an optional dependency. When it cannot be imported, `hpy` is set to
# None and heap_dump() reports that the heap dump was skipped.
try:
  from guppy import hpy
except ImportError:
  hpy = None
# Module-level logger; used below for lull warnings and timeout errors.
_LOGGER = logging.getLogger(__name__)

# By default, this SDK harness logs a "lull" in processing if it sees no
# state transitions for more than 5 minutes.
# 5 minutes * 60 seconds * 1000 millis * 1000 micros * 1000 nanoseconds
DEFAULT_LOG_LULL_TIMEOUT_NS = 5 * 60 * 1000 * 1000 * 1000

# Full thread dump is performed at most every 20 minutes.
# NOTE(review): not referenced by the code visible in this file - confirm it
# is still used elsewhere.
LOG_LULL_FULL_THREAD_DUMP_INTERVAL_S = 20 * 60

# Full thread dump is performed if the lull is more than 20 minutes.
LOG_LULL_FULL_THREAD_DUMP_LULL_S = 20 * 60
def _current_frames():
# Work around https://github.com/python/cpython/issues/106883
if (sys.version_info.minor == 11 and sys.version_info.major == 3 and
gc.isenabled()):
gc.disable()
frames = sys._current_frames() # pylint: disable=protected-access
gc.enable()
return frames
else:
return sys._current_frames() # pylint: disable=protected-access
def thread_dump(thread_prefix=None):
  """Render the stacks of the threads in the current SDK harness as text.

  Threads that share an identical formatted stack trace are collapsed into a
  single section that lists all of them.

  Args:
    thread_prefix: (str) An optional prefix; when given, only threads whose
      name starts with it are included.

  Returns:
    A string made of a banner line, one section per distinct stack trace and a
    closing rule, joined by newlines.
  """
  frames_by_ident = _current_frames()
  candidates = threading.enumerate()
  if thread_prefix:
    candidates = [
        thread for thread in candidates
        if thread.name.startswith(thread_prefix)
    ]

  # Group (ident, name) pairs under the formatted stack they have in common.
  threads_by_stack = defaultdict(list)
  for thread in candidates:
    if thread.ident not in frames_by_ident:
      # The thread has no frame in the snapshot (e.g. it was destroyed while
      # enumerating), so there is nothing to print for it.
      continue
    formatted = ''.join(traceback.format_stack(frames_by_ident[thread.ident]))
    threads_by_stack[formatted].append((thread.ident, thread.name))

  sections = ['=' * 10 + ' THREAD DUMP ' + '=' * 10]
  for formatted, members in threads_by_stack.items():
    first_ident, first_name = members[0]
    shared = len(members) > 1
    if shared:
      suffix = 'and other %d threads' % (len(members) - 1)
    else:
      suffix = ''
    section = '--- Thread #%s name: %s %s---\n' % (
        first_ident, first_name, suffix)
    if shared:
      section += 'threads: %s\n' % members
    sections.append(section + formatted)
  sections.append('=' * 30)
  return '\n'.join(sections)
def heap_dump():
  """Return a heap summary for the current SDK worker harness as text.

  When guppy could not be imported (``hpy`` is falsy) the body of the dump is
  a notice that the heap dump was skipped.

  Returns:
    A string consisting of a banner line, the heap description (or the skip
    notice) and a closing rule.
  """
  if hpy:
    body = '%s\n' % hpy().heap()
  else:
    body = 'Unable to import guppy, the heap dump will be skipped.\n'
  header = '=' * 10 + ' HEAP DUMP ' + '=' * 10 + '\n'
  footer = '=' * 30
  return ''.join((header, body, footer))
def _state_cache_stats(state_cache: StateCache) -> str:
"""Gather state cache statistics."""
cache_stats = ['=' * 10 + ' CACHE STATS ' + '=' * 10]
if not state_cache.is_cache_enabled():
cache_stats.append("Cache disabled")
else:
cache_stats.append(state_cache.describe_stats())
return '\n'.join(cache_stats)
def _active_processing_bundles_state(bundle_processor_cache):
"""Gather information about the currently in-processing active bundles.
The result only keeps the longest lasting 10 bundles to avoid excessive
spamming.
"""
active_bundles = ['=' * 10 + ' ACTIVE PROCESSING BUNDLES ' + '=' * 10]
if (not bundle_processor_cache.active_bundle_processors and
not bundle_processor_cache.processors_being_created):
active_bundles.append("No active processing bundles.")
else:
cache = []
for instruction in list(
bundle_processor_cache.active_bundle_processors.keys()):
processor = bundle_processor_cache.lookup(instruction)
if processor:
info = processor.state_sampler.get_info()
cache.append((
instruction,
processor.process_bundle_descriptor.id,
info.tracked_thread,
info.time_since_transition))
# reverse sort active bundle by time since last transition, keep top 10.
cache.sort(key=lambda x: x[-1], reverse=True)
for s in cache[:10]:
state = '--- instruction %s ---\n' % s[0]
state += 'ProcessBundleDescriptorId: %s\n' % s[1]
state += "tracked thread: %s\n" % s[2]
state += "time since transition: %.2f seconds\n" % (s[3] / 1e9)
active_bundles.append(state)
if bundle_processor_cache.processors_being_created:
active_bundles.append("Processors being created:\n")
current_time = time.time()
for instruction, (bundle_id, thread, creation_time) in (
bundle_processor_cache.processors_being_created.items()):
state = '--- instruction %s ---\n' % instruction
state += 'ProcessBundleDescriptorId: %s\n' % bundle_id
state += "tracked thread: %s\n" % thread
state += "time since creation started: %.2f seconds\n" % (
current_time - creation_time)
active_bundles.append(state)
active_bundles.append('=' * 30)
return '\n'.join(active_bundles)
# Marker object that FnApiWorkerStatusHandler.close() puts on the response
# queue; the response generator stops when it dequeues this object.
DONE = Sentinel.sentinel
class FnApiWorkerStatusHandler(object):
  """FnApiWorkerStatusHandler handles worker status request from Runner.

  Constructing the handler connects to the status endpoint and starts two
  daemon threads:

  * ``fn_api_status_handler`` answers each WorkerStatus request with a textual
    status page (cache stats, active bundles, thread dump and, optionally, a
    heap dump).
  * ``lull_operation_logger`` wakes up every two minutes and logs bundles that
    have not made progress for longer than the lull timeout; when an element
    processing timeout is configured and exceeded it asks for the SDK harness
    to be terminated.
  """
  def __init__(
      self,
      status_address,
      bundle_processor_cache=None,
      state_cache=None,
      enable_heap_dump=False,
      worker_id=None,
      log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS,
      element_processing_timeout_minutes=None):
    """Initialize FnApiWorkerStatusHandler.

    Args:
      status_address: The URL Runner uses to host the WorkerStatus server.
      bundle_processor_cache: The BundleProcessor cache dict from sdk worker.
      state_cache: The StateCache from sdk worker.
      enable_heap_dump: Whether the status page also contains a heap dump.
      worker_id: Passed to the WorkerIdInterceptor installed on the status
        channel.
      log_lull_timeout_ns: Time, in nanoseconds, without progress after which
        a lull warning is logged. It is also the minimum interval between two
        lull warnings.
      element_processing_timeout_minutes: If set to a positive number, the time
        in minutes after which a stuck bundle (or bundle processor creation)
        is logged as an error and termination of the SDK harness is requested.
        ``None``, zero or a negative value disables the timeout.

    Raises:
      grpc.FutureTimeoutError: If the status channel is not ready within 60
        seconds. The channel is closed before the error propagates.
    """
    self._alive = True
    self._bundle_processor_cache = bundle_processor_cache
    self._state_cache = state_cache
    ch = GRPCChannelFactory.insecure_channel(status_address)
    try:
      grpc.channel_ready_future(ch).result(timeout=60)
    except Exception:
      # Nobody else holds a reference to the channel yet, so release it here
      # instead of leaking it when construction fails.
      ch.close()
      raise
    self._status_channel = grpc.intercept_channel(
        ch, WorkerIdInterceptor(worker_id))
    self._status_stub = beam_fn_api_pb2_grpc.BeamFnWorkerStatusStub(
        self._status_channel)
    # Responses produced by _serve() and consumed by _get_responses().
    self._responses = queue.Queue()
    self.log_lull_timeout_ns = log_lull_timeout_ns
    if (element_processing_timeout_minutes and
        element_processing_timeout_minutes > 0):
      # minutes -> seconds -> nanoseconds
      self._element_processing_timeout_ns = (
          element_processing_timeout_minutes * 60 * 1e9)
    else:
      self._element_processing_timeout_ns = None
    self._last_full_thread_dump_secs = 0.0
    self._last_lull_logged_secs = 0.0
    # Every attribute read by the worker threads is assigned before the
    # corresponding thread is started.
    self._enable_heap_dump = enable_heap_dump
    self._server = threading.Thread(
        target=self._serve, name='fn_api_status_handler')
    self._server.daemon = True
    self._server.start()
    self._lull_logger = threading.Thread(
        target=self._log_lull_in_bundle_processor,
        args=(self._bundle_processor_cache, ),
        name='lull_operation_logger')
    self._lull_logger.daemon = True
    self._lull_logger.start()

  def _get_responses(self):
    """Yield queued responses until the DONE marker is dequeued.

    Dequeuing DONE also marks the handler as no longer alive, which stops the
    outer loop in _serve().
    """
    while True:
      response = self._responses.get()
      if response is DONE:
        self._alive = False
        return
      yield response

  def _serve(self):
    """Answer WorkerStatus requests while the handler is alive.

    Each incoming request is answered by enqueueing a response carrying either
    the generated status page or, if generating it failed, the traceback of
    the failure.
    """
    while self._alive:
      for request in self._status_stub.WorkerStatus(self._get_responses()):
        try:
          self._responses.put(
              beam_fn_api_pb2.WorkerStatusResponse(
                  id=request.id, status_info=self.generate_status_response()))
        except Exception:
          # Report the failure to the requester instead of dropping the
          # request.
          traceback_string = traceback.format_exc()
          self._responses.put(
              beam_fn_api_pb2.WorkerStatusResponse(
                  id=request.id,
                  error="Exception encountered while generating "
                  "status page: %s" % traceback_string))

  def generate_status_response(self):
    """Build the textual status page.

    Returns:
      The available sections joined by newlines: state cache stats (if a state
      cache was given), active bundles (if a bundle processor cache was given),
      a thread dump and, when enabled, a heap dump.
    """
    all_status_sections = []
    if self._state_cache:
      all_status_sections.append(_state_cache_stats(self._state_cache))
    if self._bundle_processor_cache:
      all_status_sections.append(
          _active_processing_bundles_state(self._bundle_processor_cache))
    all_status_sections.append(thread_dump())
    if self._enable_heap_dump:
      all_status_sections.append(heap_dump())
    return '\n'.join(all_status_sections)

  def close(self):
    """Signal the response generator to stop by enqueueing the DONE marker."""
    self._responses.put(DONE, timeout=5)

  def _log_lull_in_bundle_processor(self, bundle_processor_cache):
    """Every two minutes, check active and in-creation bundles for lulls.

    Args:
      bundle_processor_cache: The cache to inspect; while it is falsy each
        round is skipped.
    """
    while True:
      time.sleep(2 * 60)
      if not bundle_processor_cache:
        continue
      # Both mappings are snapshotted so that changes made to them during the
      # loop cannot break the iteration.
      for instruction in list(
          bundle_processor_cache.active_bundle_processors.keys()):
        processor = bundle_processor_cache.lookup(instruction)
        if processor:
          info = processor.state_sampler.get_info()
          self._log_lull_sampler_info(info, instruction)
      for instruction, (bundle_id, thread, creation_time) in list(
          bundle_processor_cache.processors_being_created.items()):
        self._log_lull_in_creating_bundle_descriptor(
            instruction, bundle_id, thread, creation_time)

  def _log_lull_in_creating_bundle_descriptor(
      self, instruction, bundle_id, thread, creation_time):
    """Log (and possibly terminate) when creating a bundle processor is slow.

    Args:
      instruction: Instruction id the processor is being created for.
      bundle_id: Id of the bundle descriptor being instantiated.
      thread: Thread doing the creation; used to fetch a stack trace.
      creation_time: Start of the creation, comparable with ``time.time()``.
    """
    time_since_creation_ns = (time.time() - creation_time) * 1e9
    if (self._element_processing_timeout_ns and
        time_since_creation_ns > self._element_processing_timeout_ns):
      stack_trace = self._get_stack_trace(thread)
      _LOGGER.error((
          'Creation of bundle processor for instruction %s (bundle %s) '
          'has exceeded the specified timeout of %.2f minutes. '
          'This might indicate stuckness in DoFn.setup() or in DoFn creation. '
          'SDK harness will be terminated.\n'
          'Current Traceback:\n%s'),
                    instruction,
                    bundle_id,
                    self._element_processing_timeout_ns / 1e9 / 60,
                    stack_trace)
      # Imported at call time rather than at module level
      # (NOTE(review): presumably to avoid an import cycle - confirm).
      from apache_beam.runners.worker.sdk_worker_main import terminate_sdk_harness
      terminate_sdk_harness()
    if (time_since_creation_ns > self.log_lull_timeout_ns and
        self._passed_lull_timeout_since_last_log()):
      stack_trace = self._get_stack_trace(thread)
      _LOGGER.warning((
          'Bundle processor for instruction %s (bundle %s) '
          'has been creating for at least %.2f seconds.\n'
          'This might indicate slowness in DoFn.setup() or in DoFn creation. '
          'Current Traceback:\n%s'),
                      instruction,
                      bundle_id,
                      time_since_creation_ns / 1e9,
                      stack_trace)

  def _log_lull_sampler_info(self, sampler_info, instruction):
    """Log (and possibly terminate) when a bundle has stopped progressing.

    Args:
      sampler_info: State sampler info of the bundle; nothing happens when it
        is falsy or its ``time_since_transition`` is falsy.
      instruction: Instruction id of the bundle, used in the log message.
    """
    if (not sampler_info or not sampler_info.time_since_transition):
      return
    # Note: evaluating log_lull records the time of this log as a side effect
    # when it is True (see _passed_lull_timeout_since_last_log).
    log_lull = (
        sampler_info.time_since_transition > self.log_lull_timeout_ns and
        self._passed_lull_timeout_since_last_log())
    timeout_exceeded = (
        self._element_processing_timeout_ns and
        sampler_info.time_since_transition
        > self._element_processing_timeout_ns)
    if not (log_lull or timeout_exceeded):
      return

    lull_seconds = sampler_info.time_since_transition / 1e9
    step_name = sampler_info.state_name.step_name
    state_name = sampler_info.state_name.name
    if step_name and state_name:
      step_name_log = (
          ' for PTransform{name=%s, state=%s}' % (step_name, state_name))
    else:
      step_name_log = ''
    stack_trace = self._get_stack_trace(sampler_info.tracked_thread)

    if timeout_exceeded:
      _LOGGER.error(
          (
              'Processing of an element in bundle %s%s has exceeded the '
              'specified timeout of %.2f minutes. SDK harness will be '
              'terminated.\n'
              'Current Traceback:\n%s'),
          instruction,
          step_name_log,
          self._element_processing_timeout_ns / 1e9 / 60,
          stack_trace,
      )
      # Imported at call time rather than at module level
      # (NOTE(review): presumably to avoid an import cycle - confirm).
      from apache_beam.runners.worker.sdk_worker_main import terminate_sdk_harness
      terminate_sdk_harness()

    if log_lull:
      _LOGGER.warning(
          (
              'Operation ongoing in bundle %s%s for at least %.2f seconds'
              ' without outputting or completing.\n'
              'Current Traceback:\n%s'),
          instruction,
          step_name_log,
          lull_seconds,
          stack_trace,
      )

  def _get_stack_trace(self, thread):
    """Return the formatted current stack of ``thread``.

    Returns:
      '-NOT AVAILABLE-' when ``thread`` is falsy, an empty string when no
      frame is found for the thread, otherwise the formatted stack entries
      joined by newlines.
    """
    if not thread:
      return '-NOT AVAILABLE-'
    thread_frame = _current_frames().get(thread.ident)
    if not thread_frame:
      return ''
    return '\n'.join(traceback.format_stack(thread_frame))

  def _passed_lull_timeout_since_last_log(self) -> bool:
    """Tell whether enough time has passed to log another lull.

    Returns:
      True, after recording the current time as the last log time, when more
      than ``log_lull_timeout_ns`` has elapsed since the previously recorded
      log time; False otherwise.
    """
    # Read the clock once so the comparison and the stored timestamp agree.
    now = time.time()
    if now - self._last_lull_logged_secs > self.log_lull_timeout_ns / 1e9:
      self._last_lull_logged_secs = now
      return True
    return False