blob: 5124bb69e6c88c3dd306dd8eb1f9306231c8dc08 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for apache_beam.runners.worker.data_plane."""
# pytype: skip-file
import itertools
import logging
import time
import unittest
import grpc
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils import thread_pool_executor
class DataChannelTest(unittest.TestCase):
  """Tests for the worker data plane channels, over gRPC and in memory."""

  def test_grpc_data_channel(self):
    """gRPC data channel where streams are flushed by explicit close()."""
    self._grpc_data_channel_test()

  def test_time_based_flush_grpc_data_channel(self):
    """gRPC data channel relying on the periodic time-based buffer flush."""
    self._grpc_data_channel_test(True)

  def _grpc_data_channel_test(self, time_based_flush=False):
    """Wires a BeamFnDataServicer and a GrpcClientDataChannel together.

    Starts an in-process gRPC server hosting the data servicer, connects a
    client stub through a WorkerIdInterceptor (so requests carry the worker
    id), then runs the bidirectional scenario in _data_channel_test.

    Args:
      time_based_flush: when True, both server and client channels are
        configured with a 100ms buffer time limit so written data is
        delivered without closing the output stream.
    """
    if time_based_flush:
      data_servicer = data_plane.BeamFnDataServicer(
          data_buffer_time_limit_ms=100)
    else:
      data_servicer = data_plane.BeamFnDataServicer()
    worker_id = 'worker_0'
    # Server-side view of the connection for this worker id.
    data_channel_service = \
        data_servicer.get_conn_by_worker_id(worker_id)
    server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(data_servicer, server)
    # Port 0 asks the OS for any free port, avoiding collisions in tests.
    test_port = server.add_insecure_port('[::]:0')
    server.start()
    grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
    # Add workerId to the grpc channel
    grpc_channel = grpc.intercept_channel(
        grpc_channel, WorkerIdInterceptor(worker_id))
    data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)
    if time_based_flush:
      data_channel_client = data_plane.GrpcClientDataChannel(
          data_channel_stub, data_buffer_time_limit_ms=100)
    else:
      data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)
    try:
      self._data_channel_test(
          data_channel_service, data_channel_client, time_based_flush)
    finally:
      # Close both ends before waiting so neither wait() can block on the
      # other end still being open.
      data_channel_client.close()
      data_channel_service.close()
      data_channel_client.wait()
      data_channel_service.wait()

  def test_in_memory_data_channel(self):
    """In-memory channel: each side reads what its inverse wrote."""
    channel = data_plane.InMemoryDataChannel()
    self._data_channel_test(channel, channel.inverse())

  def _data_channel_test(self, server, client, time_based_flush=False):
    """Runs the read/write scenario in both directions over the channel pair."""
    self._data_channel_test_one_direction(server, client, time_based_flush)
    self._data_channel_test_one_direction(client, server, time_based_flush)

  def _data_channel_test_one_direction(
      self, from_channel, to_channel, time_based_flush):
    """Writes on from_channel and asserts the bytes arrive on to_channel.

    Covers a single write, interleaved writes to multiple instructions, and
    multiple transforms within one instruction. When time_based_flush is
    True, streams are never closed explicitly; delivery relies on the
    periodic flush, so reads use itertools.islice to take only the expected
    number of elements from the (still open) input.
    """
    transform_1 = '1'
    transform_2 = '2'
    # Single write.
    stream01 = from_channel.output_stream('0', transform_1)
    stream01.write(b'abc')
    if not time_based_flush:
      stream01.close()
    self.assertEqual(
        list(
            itertools.islice(to_channel.input_elements('0', [transform_1]), 1)),
        [
            beam_fn_api_pb2.Elements.Data(
                instruction_id='0', transform_id=transform_1, data=b'abc')
        ])
    # Multiple interleaved writes to multiple instructions.
    stream11 = from_channel.output_stream('1', transform_1)
    stream11.write(b'abc')
    stream21 = from_channel.output_stream('2', transform_1)
    stream21.write(b'def')
    if not time_based_flush:
      stream11.close()
    self.assertEqual(
        list(
            itertools.islice(to_channel.input_elements('1', [transform_1]), 1)),
        [
            beam_fn_api_pb2.Elements.Data(
                instruction_id='1', transform_id=transform_1, data=b'abc')
        ])
    if time_based_flush:
      # Wait to ensure stream21 is flushed before stream22.
      # Because the flush callback is invoked periodically starting from when a
      # stream is constructed, there is no guarantee that one stream's callback
      # is called before the other.
      time.sleep(0.1)
    else:
      stream21.close()
    stream22 = from_channel.output_stream('2', transform_2)
    stream22.write(b'ghi')
    if not time_based_flush:
      stream22.close()
    # The islice(..., 2) order below (transform_1 then transform_2) relies on
    # stream21's data arriving before stream22's — guaranteed either by the
    # explicit close() above or by the sleep in the time-based-flush case.
    self.assertEqual(
        list(
            itertools.islice(
                to_channel.input_elements('2', [transform_1, transform_2]), 2)),
        [
            beam_fn_api_pb2.Elements.Data(
                instruction_id='2', transform_id=transform_1, data=b'def'),
            beam_fn_api_pb2.Elements.Data(
                instruction_id='2', transform_id=transform_2, data=b'ghi')
        ])
if __name__ == '__main__':
  # Surface INFO-level logs from the workers while the tests run.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()