blob: 5ca307ca1b3761c5ba0853a58ccb26f30114e675 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Functionaliry for sampling elements during bundle execution."""
# pytype: skip-file
from __future__ import annotations
import collections
import logging
import threading
import time
import traceback
from dataclasses import dataclass
from threading import Timer
from typing import Any
from typing import Deque
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from apache_beam.coders.coder_impl import CoderImpl
from apache_beam.coders.coder_impl import WindowedValueCoderImpl
from apache_beam.coders.coders import Coder
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.utils.windowed_value import WindowedValue
_LOGGER = logging.getLogger(__name__)
class SampleTimer:
  """Periodic timer that drives sampling on an OutputSampler.

  For the first 30 seconds of sampling the timer fires quickly (at most every
  0.5 seconds); after that it falls back to the configured timeout.
  """
  def __init__(self, timeout_secs: float, sampler: OutputSampler) -> None:
    self._target_timeout_secs = timeout_secs
    # Start with a fast cadence, capped at 0.5s. A non-positive timeout means
    # the timer is never started (see OutputSampler, which only calls reset()
    # for positive periods).
    if timeout_secs > 0:
      self._timeout_secs = min(timeout_secs, 0.5)
    else:
      self._timeout_secs = 0.0
    self._timer = Timer(self._timeout_secs, self.sample)
    self._sampler = sampler
    self._sample_duration_secs = 0.0

  def reset(self) -> None:
    """Cancels any pending fire and schedules the next one."""
    # After 30 cumulative seconds of sampling, switch from the fast initial
    # cadence to the configured rate.
    if self._sample_duration_secs >= 30.0:
      self._timeout_secs = self._target_timeout_secs
    self._sample_duration_secs += self._timeout_secs

    self._timer.cancel()
    self._timer = Timer(self._timeout_secs, self.sample)
    self._timer.start()

  def stop(self) -> None:
    """Cancels the pending timer, if any."""
    self._timer.cancel()

  def sample(self) -> None:
    """Takes a sample and re-arms the timer for the next period."""
    self._sampler.sample()
    self.reset()
@dataclass
class ExceptionMetadata:
  """Metadata about an exception sampled during bundle processing."""
  # The formatted exception traceback (see OutputSampler.sample_exception).
  msg: str
  # The transform where the exception occurred.
  transform_id: str
  # The instruction when the exception occurred.
  instruction_id: str
@dataclass
class ElementSampler:
  """Record class to hold sampled elements.

  This class is used as an optimization to quickly sample elements. This is a
  shared reference between the Operation and the OutputSampler.
  """
  # Is true iff the `el` has been set with a sample.
  has_element: bool = False
  # The sampled element. Note that `None` is a valid element and cannot be used
  # as a sentinel to check if there is a sample. Use the `has_element` flag to
  # check for this case.
  el: Any = None
class OutputSampler:
  """Samples elements flowing through one output of a PTransform.

  At most `max_samples` (see constructor) sampled elements are kept in memory;
  a sample is taken every `sample_every_sec` seconds by a SampleTimer.
  """
  def __init__(
      self,
      coder: Coder,
      max_samples: int = 10,
      sample_every_sec: float = 5) -> None:
    self._samples: Deque[Any] = collections.deque(maxlen=max_samples)
    self._samples_lock: threading.Lock = threading.Lock()
    self._coder_impl: CoderImpl = coder.get_impl()
    self._sample_timer = SampleTimer(sample_every_sec, self)
    self.element_sampler = ElementSampler()
    self.element_sampler.has_element = False
    self._exceptions: Deque[Tuple[Any, ExceptionMetadata]] = collections.deque(
        maxlen=max_samples)

    # A non-positive period disables the Timer, which makes manual sampling in
    # tests easier.
    if sample_every_sec > 0:
      self._sample_timer.reset()

  def stop(self) -> None:
    """Stops sampling."""
    self._sample_timer.stop()

  def remove_windowed_value(self, el: Union[WindowedValue, Any]) -> Any:
    """Unwraps (possibly nested) WindowedValues down to the raw value.

    The Python SDK passes elements as WindowedValues, which may not match the
    coder for that particular PCollection.
    """
    value = el
    while isinstance(value, WindowedValue):
      value = value.value
    return value

  def flush(self, clear: bool = True) -> List[beam_fn_api_pb2.SampledElement]:
    """Returns all samples and optionally clears buffer if clear is True."""
    with self._samples_lock:
      # TODO(rohdesamuel): There can duplicates between the exceptions and
      # samples. This happens when the OutputSampler samples during an
      # exception. The fix is to create a OutputSampler per process bundle.
      # Until then use a set to keep track of the elements.
      exception_ids = {id(el) for el, _ in self._exceptions}

      # When the coder is not a WindowedValue coder, the raw values must be
      # extracted before encoding.
      unwrap = not isinstance(self._coder_impl, WindowedValueCoderImpl)
      if unwrap:
        exceptions = [(self.remove_windowed_value(el), meta)
                      for el, meta in self._exceptions]
        samples = [
            self.remove_windowed_value(el) for el in self._samples
            if id(el) not in exception_ids
        ]
      else:
        exceptions = list(self._exceptions)
        samples = [el for el in self._samples if id(el) not in exception_ids]

      if clear:
        self._samples.clear()
        self._exceptions.clear()

      # Encode in the nested context b/c this ensures that the SDK can decode
      # the bytes with the ToStringFn.
      encode = self._coder_impl.encode_nested
      ret = [
          beam_fn_api_pb2.SampledElement(element=encode(el)) for el in samples
      ]
      for el, meta in exceptions:
        ret.append(
            beam_fn_api_pb2.SampledElement(
                element=encode(el),
                exception=beam_fn_api_pb2.SampledElement.Exception(
                    instruction_id=meta.instruction_id,
                    transform_id=meta.transform_id,
                    error=meta.msg)))
      return ret

  def sample(self) -> None:
    """Moves the pending element, if any, into the internal sample buffer."""
    with self._samples_lock:
      if self.element_sampler.has_element:
        self._samples.append(self.element_sampler.el)
        self.element_sampler.has_element = False

  def sample_exception(
      self, el: Any, exc_info: Any, transform_id: str,
      instruction_id: str) -> None:
    """Adds the given exception to the samples."""
    with self._samples_lock:
      formatted = ''.join(traceback.format_exception(*exc_info))
      self._exceptions.append(
          (el, ExceptionMetadata(formatted, transform_id, instruction_id)))
class DataSampler:
  """A class for querying any samples generated during execution.

  This class is meant to be a singleton with regard to a particular
  `sdk_worker.SdkHarness`. When creating the operators, individual
  `OutputSampler`s are created from `DataSampler.initialize_samplers`. This
  allows for multi-threaded sampling of a PCollection across the SdkHarness.

  Samples generated during execution can then be sampled with the `samples`
  method. This filters samples from the given pcollection ids.
  """

  # Experiment names controlling sampling behavior (see `create`).
  _ENABLE_DATA_SAMPLING = 'enable_data_sampling'
  _ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING = 'enable_always_on_exception_sampling'
  _DISABLE_ALWAYS_ON_EXCEPTION_SAMPLING = 'disable_always_on_exception_sampling'

  def __init__(
      self,
      max_samples: int = 10,
      sample_every_sec: float = 30,
      sample_only_exceptions: bool = False,
      clock=None) -> None:
    # Key is PCollection id. Is guarded by the _samplers_lock.
    self._samplers: Dict[str, OutputSampler] = {}
    # Bundles are processed in parallel, so new samplers may be added when the
    # runner queries for samples.
    self._samplers_lock: threading.Lock = threading.Lock()
    self._max_samples = max_samples
    # A period of 0.0 disables periodic sampling (exceptions are still
    # recorded by the OutputSamplers).
    self._sample_every_sec = 0.0 if sample_only_exceptions else sample_every_sec
    self._samplers_by_output: Dict[str, List[OutputSampler]] = {}
    self._clock = clock

  @staticmethod
  def create(sdk_pipeline_options: PipelineOptions, **kwargs):
    """Returns a DataSampler if enabled via experiments, otherwise None."""
    experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []

    # When true, enables only the sampling of exceptions.
    always_on_exception_sampling = (
        DataSampler._ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING in experiments and
        DataSampler._DISABLE_ALWAYS_ON_EXCEPTION_SAMPLING not in experiments)

    # When true, enables the sampling of all PCollections and exceptions.
    enable_data_sampling = DataSampler._ENABLE_DATA_SAMPLING in experiments

    if enable_data_sampling or always_on_exception_sampling:
      sample_only_exceptions = (
          always_on_exception_sampling and not enable_data_sampling)
      return DataSampler(
          sample_only_exceptions=sample_only_exceptions, **kwargs)
    else:
      return None

  def stop(self) -> None:
    """Stops all sampling, does not clear samplers in case there are
    outstanding samples.
    """
    with self._samplers_lock:
      for sampler in self._samplers.values():
        sampler.stop()

  def sampler_for_output(self, transform_id: str,
                         output_index: int) -> Optional[OutputSampler]:
    """Returns the OutputSampler for the given output, or None if missing."""
    try:
      with self._samplers_lock:
        outputs = self._samplers_by_output[transform_id]
        return outputs[output_index]
    # BUG FIX: `outputs[output_index]` raises IndexError on the out-of-bounds
    # case this handler describes, but only KeyError (unknown transform) was
    # previously caught.
    except (KeyError, IndexError):
      # BUG FIX: the second string was missing its f-prefix, so the literal
      # text "{output_index}" was logged instead of the index.
      _LOGGER.warning(
          f'Out-of-bounds access for transform "{transform_id}" '
          f'and output "{output_index}" OutputSampler. This may '
          'indicate that the transform was improperly '
          'initialized with the DataSampler.')
      return None

  def initialize_samplers(
      self,
      transform_id: str,
      descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
      coder_factory) -> List[OutputSampler]:
    """Creates the OutputSamplers for the given PTransform.

    This initializes the samplers only once per PCollection Id. Note that an
    OutputSampler is created per PCollection and an ElementSampler is created
    per OutputSampler. This means that multiple ProcessBundles can and will
    share the same ElementSampler for a given PCollection.
    """
    transform_proto = descriptor.transforms[transform_id]
    with self._samplers_lock:
      if transform_id in self._samplers_by_output:
        return self._samplers_by_output[transform_id]

      # Initialize the samplers.
      for pcoll_id in transform_proto.outputs.values():
        # Only initialize new PCollections.
        if pcoll_id in self._samplers:
          continue

        # Create the sampler with the corresponding coder.
        coder_id = descriptor.pcollections[pcoll_id].coder_id
        coder = coder_factory(coder_id)
        sampler = OutputSampler(
            coder, self._max_samples, self._sample_every_sec)
        self._samplers[pcoll_id] = sampler

      # Next update the lookup table for ElementSamplers for a given
      # PTransform. Operations look up the ElementSampler for an output based
      # on the index of the tag in the PTransform's outputs. The following
      # code initializes the array with ElementSamplers in the correct
      # indices.
      outputs = transform_proto.outputs
      samplers = [self._samplers[pcoll_id] for pcoll_id in outputs.values()]
      self._samplers_by_output[transform_id] = samplers

      return samplers

  def samples(
      self,
      pcollection_ids: Optional[Iterable[str]] = None
  ) -> beam_fn_api_pb2.SampleDataResponse:
    """Returns samples filtered by PCollection ids.

    All samples from the given PCollections are returned. Empty lists are
    wildcards.
    """
    ret = beam_fn_api_pb2.SampleDataResponse()

    # Copy under the lock, then flush outside it so encoding does not block
    # concurrent sampler registration.
    with self._samplers_lock:
      samplers = self._samplers.copy()

    for pcoll_id in samplers:
      if pcollection_ids and pcoll_id not in pcollection_ids:
        continue

      samples = samplers[pcoll_id].flush()
      if samples:
        ret.element_samples[pcoll_id].elements.extend(samples)

    return ret

  def wait_for_samples(
      self, pcollection_ids: List[str]) -> beam_fn_api_pb2.SampleDataResponse:
    """Waits for samples to exist for the given PCollections (only testing)."""
    now = time.time()
    end = now + 30

    samples = beam_fn_api_pb2.SampleDataResponse()
    while now < end:
      time.sleep(0.1)
      now = time.time()

      samples.MergeFrom(self.samples(pcollection_ids))
      # BUG FIX: `if not samples:` was dead code — protobuf messages are
      # always truthy — so test the element_samples map field instead.
      if not samples.element_samples:
        continue

      has_all = all(
          pcoll_id in samples.element_samples for pcoll_id in pcollection_ids)
      if has_all:
        break

    return samples