blob: 26e4b60af8c8c8fc30beb968c59b0cc06c8be9ed [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Implementation of ``DataChannel``s to communicate across the data plane."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import logging
import queue
import sys
import threading
from builtins import object
from builtins import range
import grpc
from future.utils import raise_
from future.utils import with_metaclass
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
# This module is experimental. No backwards-compatibility guarantees.
_DEFAULT_FLUSH_THRESHOLD = 10 << 20 # 10MB
class ClosableOutputStream(type(coder_impl.create_OutputStream())):
"""A Outputstream for use with CoderImpls that has a close() method."""
def __init__(self,
close_callback=None,
flush_callback=None,
flush_threshold=_DEFAULT_FLUSH_THRESHOLD):
super(ClosableOutputStream, self).__init__()
self._close_callback = close_callback
self._flush_callback = flush_callback
self._flush_threshold = flush_threshold
# This must be called explicitly to avoid flushing partial elements.
def maybe_flush(self):
if self._flush_callback and self.size() > self._flush_threshold:
self._flush_callback(self.get())
self._clear()
def close(self):
if self._close_callback:
self._close_callback(self.get())
class DataChannel(with_metaclass(abc.ABCMeta, object)):
"""Represents a channel for reading and writing data over the data plane.
Read from this channel with the input_elements method::
for elements_data in data_channel.input_elements(
instruction_id, transform_ids):
[process elements_data]
Write to this channel using the output_stream method::
out1 = data_channel.output_stream(instruction_id, transform_id)
out1.write(...)
out1.close()
When all data for all instructions is written, close the channel::
data_channel.close()
"""
@abc.abstractmethod
def input_elements(
self, instruction_id, expected_transforms, abort_callback=None):
"""Returns an iterable of all Element.Data bundles for instruction_id.
This iterable terminates only once the full set of data has been recieved
for each of the expected transforms. It may block waiting for more data.
Args:
instruction_id: which instruction the results must belong to
expected_transforms: which transforms to wait on for completion
abort_callback: a callback to invoke if blocking returning whether
to abort before consuming all the data
"""
raise NotImplementedError(type(self))
@abc.abstractmethod
def output_stream(self, instruction_id, transform_id):
"""Returns an output stream writing elements to transform_id.
Args:
instruction_id: which instruction this stream belongs to
transform_id: the transform_id of the returned stream
"""
raise NotImplementedError(type(self))
@abc.abstractmethod
def close(self):
"""Closes this channel, indicating that all data has been written.
Data can continue to be read.
If this channel is shared by many instructions, should only be called on
worker shutdown.
"""
raise NotImplementedError(type(self))
class InMemoryDataChannel(DataChannel):
"""An in-memory implementation of a DataChannel.
This channel is two-sided. What is written to one side is read by the other.
The inverse() method returns the other side of a instance.
"""
def __init__(self, inverse=None):
self._inputs = []
self._inverse = inverse or InMemoryDataChannel(self)
def inverse(self):
return self._inverse
def input_elements(self, instruction_id, unused_expected_transforms=None,
abort_callback=None):
other_inputs = []
for data in self._inputs:
if data.instruction_id == instruction_id:
if data.data:
yield data
else:
other_inputs.append(data)
self._inputs = other_inputs
def output_stream(self, instruction_id, transform_id):
def add_to_inverse_output(data):
self._inverse._inputs.append( # pylint: disable=protected-access
beam_fn_api_pb2.Elements.Data(
instruction_id=instruction_id,
transform_id=transform_id,
data=data))
return ClosableOutputStream(
add_to_inverse_output, flush_callback=add_to_inverse_output)
def close(self):
pass
class _GrpcDataChannel(DataChannel):
"""Base class for implementing a BeamFnData-based DataChannel."""
_WRITES_FINISHED = object()
def __init__(self):
self._to_send = queue.Queue()
self._received = collections.defaultdict(lambda: queue.Queue(maxsize=5))
self._receive_lock = threading.Lock()
self._reads_finished = threading.Event()
self._closed = False
self._exc_info = None
def close(self):
self._to_send.put(self._WRITES_FINISHED)
self._closed = True
def wait(self, timeout=None):
self._reads_finished.wait(timeout)
def _receiving_queue(self, instruction_id):
with self._receive_lock:
return self._received[instruction_id]
def _clean_receiving_queue(self, instruction_id):
with self._receive_lock:
self._received.pop(instruction_id)
def input_elements(self, instruction_id, expected_transforms,
abort_callback=None):
"""
Generator to retrieve elements for an instruction_id
input_elements should be called only once for an instruction_id
Args:
instruction_id(str): instruction_id for which data is read
expected_transforms(collection): expected transforms
"""
received = self._receiving_queue(instruction_id)
done_transforms = []
abort_callback = abort_callback or (lambda: False)
try:
while len(done_transforms) < len(expected_transforms):
try:
data = received.get(timeout=1)
except queue.Empty:
if self._closed:
raise RuntimeError('Channel closed prematurely.')
if abort_callback():
return
if self._exc_info:
t, v, tb = self._exc_info
raise_(t, v, tb)
else:
if not data.data and data.transform_id in expected_transforms:
done_transforms.append(data.transform_id)
else:
assert data.transform_id not in done_transforms
yield data
finally:
# Instruction_ids are not reusable so Clean queue once we are done with
# an instruction_id
self._clean_receiving_queue(instruction_id)
def output_stream(self, instruction_id, transform_id):
def add_to_send_queue(data):
if data:
self._to_send.put(
beam_fn_api_pb2.Elements.Data(
instruction_id=instruction_id,
transform_id=transform_id,
data=data))
def close_callback(data):
add_to_send_queue(data)
# End of stream marker.
self._to_send.put(
beam_fn_api_pb2.Elements.Data(
instruction_id=instruction_id,
transform_id=transform_id,
data=b''))
return ClosableOutputStream(
close_callback, flush_callback=add_to_send_queue)
def _write_outputs(self):
done = False
while not done:
data = [self._to_send.get()]
try:
# Coalesce up to 100 other items.
for _ in range(100):
data.append(self._to_send.get_nowait())
except queue.Empty:
pass
if data[-1] is self._WRITES_FINISHED:
done = True
data.pop()
if data:
yield beam_fn_api_pb2.Elements(data=data)
def _read_inputs(self, elements_iterator):
try:
for elements in elements_iterator:
for data in elements.data:
self._receiving_queue(data.instruction_id).put(data)
except: # pylint: disable=bare-except
if not self._closed:
logging.exception('Failed to read inputs in the data plane.')
self._exc_info = sys.exc_info()
raise
finally:
self._closed = True
self._reads_finished.set()
def set_inputs(self, elements_iterator):
reader = threading.Thread(
target=lambda: self._read_inputs(elements_iterator),
name='read_grpc_client_inputs')
reader.daemon = True
reader.start()
class GrpcClientDataChannel(_GrpcDataChannel):
"""A DataChannel wrapping the client side of a BeamFnData connection."""
def __init__(self, data_stub):
super(GrpcClientDataChannel, self).__init__()
self.set_inputs(data_stub.Data(self._write_outputs()))
class BeamFnDataServicer(beam_fn_api_pb2_grpc.BeamFnDataServicer):
"""Implementation of BeamFnDataServicer for any number of clients"""
def __init__(self):
self._lock = threading.Lock()
self._connections_by_worker_id = collections.defaultdict(
_GrpcDataChannel)
def get_conn_by_worker_id(self, worker_id):
with self._lock:
return self._connections_by_worker_id[worker_id]
def Data(self, elements_iterator, context):
worker_id = dict(context.invocation_metadata()).get('worker_id')
data_conn = self.get_conn_by_worker_id(worker_id)
data_conn.set_inputs(elements_iterator)
for elements in data_conn._write_outputs():
yield elements
class DataChannelFactory(with_metaclass(abc.ABCMeta, object)):
"""An abstract factory for creating ``DataChannel``."""
@abc.abstractmethod
def create_data_channel(self, remote_grpc_port):
"""Returns a ``DataChannel`` from the given RemoteGrpcPort."""
raise NotImplementedError(type(self))
@abc.abstractmethod
def close(self):
"""Close all channels that this factory owns."""
raise NotImplementedError(type(self))
class GrpcClientDataChannelFactory(DataChannelFactory):
"""A factory for ``GrpcClientDataChannel``.
Caches the created channels by ``data descriptor url``.
"""
def __init__(self, credentials=None, worker_id=None):
self._data_channel_cache = {}
self._lock = threading.Lock()
self._credentials = None
self._worker_id = worker_id
if credentials is not None:
logging.info('Using secure channel creds.')
self._credentials = credentials
def create_data_channel(self, remote_grpc_port):
url = remote_grpc_port.api_service_descriptor.url
if url not in self._data_channel_cache:
with self._lock:
if url not in self._data_channel_cache:
logging.info('Creating client data channel for %s', url)
# Options to have no limits (-1) on the size of the messages
# received or sent over the data plane. The actual buffer size
# is controlled in a layer above.
channel_options = [("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1)]
grpc_channel = None
if self._credentials is None:
grpc_channel = GRPCChannelFactory.insecure_channel(
url, options=channel_options)
else:
grpc_channel = GRPCChannelFactory.secure_channel(
url, self._credentials, options=channel_options)
# Add workerId to the grpc channel
grpc_channel = grpc.intercept_channel(
grpc_channel, WorkerIdInterceptor(self._worker_id))
self._data_channel_cache[url] = GrpcClientDataChannel(
beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel))
return self._data_channel_cache[url]
def close(self):
logging.info('Closing all cached grpc data channels.')
for _, channel in self._data_channel_cache.items():
channel.close()
self._data_channel_cache.clear()
class InMemoryDataChannelFactory(DataChannelFactory):
"""A singleton factory for ``InMemoryDataChannel``."""
def __init__(self, in_memory_data_channel):
self._in_memory_data_channel = in_memory_data_channel
def create_data_channel(self, unused_remote_grpc_port):
return self._in_memory_data_channel
def close(self):
pass