blob: ea09aafba317667fc15372f8a55b9e4f0edd8225 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for handling side inputs."""
# pytype: skip-file
import logging
import queue
import threading
import traceback
from collections import abc
from apache_beam.coders import observable
from apache_beam.io import iobase
from apache_beam.runners.worker import opcounters
from apache_beam.transforms import window
from apache_beam.utils.sentinel import Sentinel
# Default upper bound on the number of reader threads started per side input.
# The number actually started is additionally capped by the number of sources.
MAX_SOURCE_READER_THREADS = 15
# Capacity of the bounded queue that carries elements from the reader threads
# to the consumer. It is intentionally smaller than MAX_SOURCE_READER_THREADS
# to reduce the memory pressure of holding potentially large elements. Each
# reader thread can additionally hold one element while it waits for a free
# slot, so the number of pending elements in memory can reach the sum of
# MAX_SOURCE_READER_THREADS and ELEMENT_QUEUE_SIZE.
ELEMENT_QUEUE_SIZE = 10
# Marker object placed on the element queue by a reader thread to signal that
# the thread has finished; consumers compare against it by identity.
READER_THREAD_IS_DONE_SENTINEL = Sentinel.sentinel
# Used to efficiently window the values of non-windowed side inputs: this is
# the bound `with_value` method of a single globally-windowed template value
# that is created once here with a None payload.
_globally_windowed = window.GlobalWindows.windowed_value(None).with_value
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
class PrefetchingSourceSetIterable(object):
    """Single-use iterable that reads concurrently from a set of sources.

    The constructor immediately starts up to ``max_reader_threads`` daemon
    threads. Each thread repeatedly takes a source from a shared work queue,
    reads it, and pushes every value (as a windowed value) onto a bounded
    element queue. Iterating over the instance drains that queue. When a
    thread exits, for whatever reason, the last thing it does is post
    ``READER_THREAD_IS_DONE_SENTINEL`` so the consumer can count finished
    readers.

    An instance can be iterated at most once. Iterating over an instance that
    was built from an empty ``sources`` collection yields nothing.
    """
    def __init__(
            self,
            sources,
            max_reader_threads=MAX_SOURCE_READER_THREADS,
            read_counter=None,
            element_counter=None):
        """Sets up the queues and starts the reader threads.

        Args:
          sources: A sized collection of sources to read. Each one is either
            an ``iobase.BoundedSource`` or an object exposing ``reader()``.
          max_reader_threads: Upper bound on the number of reader threads.
            The number actually started is also capped by ``len(sources)``.
          read_counter: Optional counter used as a context manager around
            each blocking read of the element queue and to record bytes
            read. A no-op counter is used when omitted.
          element_counter: Optional counter updated for each yielded element.

        Raises:
          ValueError: If there are sources to read but ``max_reader_threads``
            is not positive, since no thread would ever read them.
        """
        self.sources = sources
        num_sources = len(self.sources)
        if num_sources and max_reader_threads < 1:
            raise ValueError(
                'max_reader_threads must be at least 1 to read %d source(s), '
                'got %r' % (num_sources, max_reader_threads))
        self.num_reader_threads = min(max_reader_threads, num_sources)

        # Queue for sources that are to be read.
        self.sources_queue = queue.Queue()
        for source in sources:
            self.sources_queue.put(source)
        # Bounded queue for elements that have been read.
        self.element_queue = queue.Queue(ELEMENT_QUEUE_SIZE)
        # Queue for exceptions encountered in reader threads; to be rethrown
        # in the consuming thread.
        self.reader_exceptions = queue.Queue()
        # Whether we have already iterated; this iterable can only be used
        # once.
        self.already_iterated = False
        # Whether an error was encountered in any source reader. Also set
        # when the consumer abandons iteration, to make the readers stop.
        self.has_errored = False

        self.read_counter = (
            read_counter or opcounters.NoOpTransformIOCounter())
        self.element_counter = element_counter

        self.reader_threads = []
        self._start_reader_threads()

    def add_byte_counter(self, reader):
        """Adds byte counter observer to a side input reader.

        Args:
          reader: A reader that should inherit from ObservableMixin to have
            bytes tracked. Readers of any other type are left untouched.
        """
        def update_bytes_read(record_size, is_record_size=False, **kwargs):
            # Let the reader report block size; only notifications flagged
            # as record sizes are counted.
            if is_record_size:
                self.read_counter.add_bytes_read(record_size)

        if isinstance(reader, observable.ObservableMixin):
            reader.register_observer(update_bytes_read)

    def _start_reader_threads(self):
        """Starts ``num_reader_threads`` daemon threads and records them."""
        for _ in range(self.num_reader_threads):
            t = threading.Thread(target=self._reader_thread)
            t.daemon = True
            t.start()
            self.reader_threads.append(t)

    def _read_source(self, source):
        """Reads one source to completion, enqueueing its values.

        Values that are not already windowed are wrapped with
        ``_globally_windowed`` before being enqueued.

        Args:
          source: An ``iobase.BoundedSource``, or otherwise an object whose
            ``reader()`` returns a context manager yielding an iterable
            reader (a native dataflow source).

        Returns:
          True if the source was read completely, False if reading was
          abandoned because ``has_errored`` was observed.
        """
        enqueue = self.element_queue.put
        if isinstance(source, iobase.BoundedSource):
            for value in source.read(source.get_range_tracker(None, None)):
                if self.has_errored:
                    # If any reader has errored, just stop.
                    return False
                if isinstance(value, window.WindowedValue):
                    enqueue(value)
                else:
                    enqueue(_globally_windowed(value))
            return True

        # Native dataflow source.
        with source.reader() as reader:
            # The tracking of time spent reading and bytes read from side
            # inputs is kept behind an experiment flag to test performance
            # impact.
            self.add_byte_counter(reader)
            returns_windowed_values = reader.returns_windowed_values
            for value in reader:
                if self.has_errored:
                    # If any reader has errored, just stop.
                    return False
                if returns_windowed_values:
                    enqueue(value)
                else:
                    enqueue(_globally_windowed(value))
        return True

    def _reader_thread(self):
        """Thread body: reads sources until none are left or an error occurs.

        Any exception is logged, queued on ``reader_exceptions`` and flagged
        through ``has_errored``. The done-sentinel is always posted last.
        """
        try:
            while True:
                # Only the dequeue is guarded by queue.Empty, so that a
                # queue.Empty escaping from a source's own read path is
                # reported as an error instead of being mistaken for
                # "no more sources".
                try:
                    source = self.sources_queue.get_nowait()
                except queue.Empty:
                    return
                if not self._read_source(source):
                    return
        except Exception as e:  # pylint: disable=broad-except
            _LOGGER.error(
                'Encountered exception in PrefetchingSourceSetIterable '
                'reader thread: %s',
                traceback.format_exc())
            # Queue the exception before raising the flag, so the consumer
            # always finds an exception to rethrow once it sees the flag.
            self.reader_exceptions.put(e)
            self.has_errored = True
        finally:
            self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)

    def __iter__(self):
        """Yields the elements produced by the reader threads.

        Yields:
          Windowed values, in the order in which they arrive on the element
          queue.

        Raises:
          RuntimeError: If the instance has already been iterated.
          Exception: Whatever exception a reader thread encountered.
        """
        # pylint: disable=too-many-nested-blocks
        if self.already_iterated:
            raise RuntimeError(
                'Can only iterate once over PrefetchingSourceSetIterable '
                'instance.')
        self.already_iterated = True

        # The invariants during execution are:
        # 1) A worker thread always posts the sentinel as the last thing it
        #    does before exiting.
        # 2) We always wait for all sentinels and then join all threads.
        num_readers_finished = 0
        try:
            # Checking the count up front (rather than only after receiving a
            # sentinel) means that with zero reader threads -- i.e. no
            # sources -- we finish immediately instead of blocking forever on
            # a queue that nobody will ever write to.
            while num_readers_finished < self.num_reader_threads:
                try:
                    with self.read_counter:
                        element = self.element_queue.get()
                    if element is READER_THREAD_IS_DONE_SENTINEL:
                        num_readers_finished += 1
                    elif self.element_counter:
                        self.element_counter.update_from(element)
                        yield element
                        self.element_counter.update_collect()
                    else:
                        yield element
                finally:
                    # Surface reader failures in the consuming thread, even
                    # when the last sentinel has just been received.
                    if self.has_errored:
                        raise self.reader_exceptions.get()
        except GeneratorExit:
            # The consumer abandoned iteration; tell the readers to stop.
            self.has_errored = True
            raise
        finally:
            # Drain the queue so readers blocked on put() can finish, then
            # wait for every thread to exit.
            while num_readers_finished < self.num_reader_threads:
                element = self.element_queue.get()
                if element is READER_THREAD_IS_DONE_SENTINEL:
                    num_readers_finished += 1
            for t in self.reader_threads:
                t.join()
def get_iterator_fn_for_sources(
        sources,
        max_reader_threads=MAX_SOURCE_READER_THREADS,
        read_counter=None,
        element_counter=None):
    """Builds a zero-argument factory for iterators over the given sources.

    Nothing is read when this function is called. Each call of the returned
    callable constructs a new PrefetchingSourceSetIterable over ``sources``
    and returns an iterator over it.

    Args:
      sources: The sources handed to PrefetchingSourceSetIterable.
      max_reader_threads: Forwarded to PrefetchingSourceSetIterable.
      read_counter: Forwarded to PrefetchingSourceSetIterable.
      element_counter: Forwarded to PrefetchingSourceSetIterable.

    Returns:
      A callable taking no arguments that returns an iterator over elements.
    """
    iterable_options = dict(
        max_reader_threads=max_reader_threads,
        read_counter=read_counter,
        element_counter=element_counter)

    def _make_iterator():
        prefetcher = PrefetchingSourceSetIterable(sources, **iterable_options)
        return iter(prefetcher)

    return _make_iterator
class EmulatedIterable(abc.Iterable):
    """Iterable facade over an iterator factory, used for side inputs.

    Every ``iter()`` call on an instance invokes the stored factory and
    returns whatever the factory returns.
    """
    def __init__(self, iterator_fn):
        """Stores the factory.

        Args:
          iterator_fn: Zero-argument callable that returns an iterator.
        """
        self.iterator_fn = iterator_fn

    def __iter__(self):
        """Returns the result of calling the stored factory."""
        make_iterator = self.iterator_fn
        return make_iterator()