[BEAM-8151, BEAM-7848] Swap to using a thread pool which is unbounded and shrinks when threads are idle.
diff --git a/sdks/python/apache_beam/runners/portability/artifact_service_test.py b/sdks/python/apache_beam/runners/portability/artifact_service_test.py
index cd9d32b..f5da724 100644
--- a/sdks/python/apache_beam/runners/portability/artifact_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/artifact_service_test.py
@@ -28,13 +28,13 @@
import tempfile
import time
import unittest
-from concurrent import futures
import grpc
from apache_beam.portability.api import beam_artifact_api_pb2
from apache_beam.portability.api import beam_artifact_api_pb2_grpc
from apache_beam.runners.portability import artifact_service
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
class AbstractArtifactServiceTest(unittest.TestCase):
@@ -76,7 +76,7 @@
self._run_staging(self._service, self._service)
def test_with_grpc(self):
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+ server = grpc.server(UnboundedThreadPoolExecutor())
try:
beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
self._service, server)
@@ -208,7 +208,7 @@
self._service, tokens[session(index)], name(index)))
# pylint: disable=range-builtin-not-iterating
- pool = futures.ThreadPoolExecutor(max_workers=10)
+ pool = UnboundedThreadPoolExecutor()
sessions = set(pool.map(put, range(100)))
tokens = dict(pool.map(commit, sessions))
# List forces materialization.
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 66b0fa3..7876246 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -17,7 +17,6 @@
from __future__ import absolute_import
import argparse
-import concurrent.futures as futures
import logging
import signal
import sys
@@ -30,6 +29,7 @@
from apache_beam.portability.api import beam_expansion_api_pb2_grpc
from apache_beam.runners.portability import expansion_service
from apache_beam.transforms import ptransform
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
# This script provides an expansion service and example ptransforms for running
# external transform test cases. See external_test.py for details.
@@ -163,7 +163,7 @@
help='port on which to serve the job api')
options = parser.parse_args()
global server
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+ server = grpc.server(UnboundedThreadPoolExecutor())
beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
expansion_service.ExpansionServiceServicer(PipelineOptions()), server
)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 0735f30..b56c26a 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -33,7 +33,6 @@
import time
import uuid
from builtins import object
-from concurrent import futures
import grpc
@@ -78,6 +77,7 @@
from apache_beam.utils import profiler
from apache_beam.utils import proto_utils
from apache_beam.utils import windowed_value
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
# This module is experimental. No backwards-compatibility guarantees.
@@ -1224,12 +1224,10 @@
_DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5
- def __init__(self, state, provision_info, max_workers):
+ def __init__(self, state, provision_info):
self.state = state
self.provision_info = provision_info
- self.max_workers = max_workers
- self.control_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=self.max_workers))
+ self.control_server = grpc.server(UnboundedThreadPoolExecutor())
self.control_port = self.control_server.add_insecure_port('[::]:0')
self.control_address = 'localhost:%s' % self.control_port
@@ -1239,12 +1237,12 @@
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1)]
self.data_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=self.max_workers),
+ UnboundedThreadPoolExecutor(),
options=no_max_message_sizes)
self.data_port = self.data_server.add_insecure_port('[::]:0')
self.state_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=self.max_workers),
+ UnboundedThreadPoolExecutor(),
options=no_max_message_sizes)
self.state_port = self.state_server.add_insecure_port('[::]:0')
@@ -1280,7 +1278,7 @@
self.state_server)
self.logging_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=2),
+ UnboundedThreadPoolExecutor(),
options=no_max_message_sizes)
self.logging_port = self.logging_server.add_insecure_port('[::]:0')
beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
@@ -1508,24 +1506,12 @@
# Any environment will do, pick one arbitrarily.
environment_id = next(iter(self._environments.keys()))
environment = self._environments[environment_id]
- max_total_workers = num_workers * len(self._environments)
# assume all environments except EMBEDDED_PYTHON use gRPC.
if environment.urn == python_urns.EMBEDDED_PYTHON:
pass # no need for a gRPC server
elif self._grpc_server is None:
- self._grpc_server = GrpcServer(self._state, self._job_provision_info,
- max_total_workers)
- elif max_total_workers > self._grpc_server.max_workers:
- # each gRPC server is running with fixed number of threads (
- # max_total_workers), which is defined by the first call to
- # get_worker_handlers(). Assumption here is a worker has a connection to a
- # gRPC server. In case a stage tries to add more workers
- # than the max_total_workers, some workers cannot connect to gRPC and
- # pipeline will hang, hence raise an error here.
- raise RuntimeError('gRPC servers are running with %s threads, we cannot '
- 'attach %s workers.' % (self._grpc_server.max_workers,
- max_total_workers))
+ self._grpc_server = GrpcServer(self._state, self._job_provision_info)
worker_handler_list = self._cached_handlers[environment_id]
if len(worker_handler_list) < num_workers:
@@ -1801,7 +1787,7 @@
merged_result = None
split_result_list = []
- with futures.ThreadPoolExecutor(max_workers=self._num_workers) as executor:
+ with UnboundedThreadPoolExecutor() as executor:
for result, split_result in executor.map(lambda part: BundleManager(
self._worker_handler_list, self._get_buffer,
self._get_input_coder_impl, self._bundle_descriptor,
diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py
index 6aad1af..b8f84ce 100644
--- a/sdks/python/apache_beam/runners/portability/local_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/local_job_service.py
@@ -26,7 +26,6 @@
import time
import traceback
from builtins import object
-from concurrent import futures
import grpc
from google.protobuf import text_format
@@ -42,6 +41,7 @@
from apache_beam.runners.portability import abstract_job_service
from apache_beam.runners.portability import artifact_service
from apache_beam.runners.portability import fn_api_runner
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
class LocalJobServicer(abstract_job_service.AbstractJobServiceServicer):
@@ -92,7 +92,7 @@
self._artifact_staging_endpoint)
def start_grpc_server(self, port=0):
- self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=3))
+ self._server = grpc.server(UnboundedThreadPoolExecutor())
port = self._server.add_insecure_port('localhost:%d' % port)
beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
@@ -139,7 +139,7 @@
self._worker_id = worker_id
def run(self):
- logging_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ logging_server = grpc.server(UnboundedThreadPoolExecutor())
logging_port = logging_server.add_insecure_port('[::]:0')
logging_server.start()
logging_servicer = BeamFnLoggingServicer()
diff --git a/sdks/python/apache_beam/runners/portability/portable_stager_test.py b/sdks/python/apache_beam/runners/portability/portable_stager_test.py
index d65c404..fd86819 100644
--- a/sdks/python/apache_beam/runners/portability/portable_stager_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_stager_test.py
@@ -27,13 +27,13 @@
import string
import tempfile
import unittest
-from concurrent import futures
import grpc
from apache_beam.portability.api import beam_artifact_api_pb2
from apache_beam.portability.api import beam_artifact_api_pb2_grpc
from apache_beam.runners.portability import portable_stager
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
class PortableStagerTest(unittest.TestCase):
@@ -56,7 +56,7 @@
describing the name of the artifacts in local temp folder and desired
name in staging location.
"""
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ server = grpc.server(UnboundedThreadPoolExecutor())
staging_service = TestLocalFileSystemArtifactStagingServiceServicer(
self._remote_dir)
beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py
index d11390a..900532b 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane_test.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py
@@ -25,7 +25,6 @@
import sys
import threading
import unittest
-from concurrent import futures
import grpc
from future.utils import raise_
@@ -34,6 +33,7 @@
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
def timeout(timeout_secs):
@@ -67,7 +67,7 @@
data_channel_service = \
data_servicer.get_conn_by_worker_id(worker_id)
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+ server = grpc.server(UnboundedThreadPoolExecutor())
beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
data_servicer, server)
test_port = server.add_insecure_port('[::]:0')
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
index ab042aa..6650ccd 100644
--- a/sdks/python/apache_beam/runners/worker/log_handler_test.py
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -20,7 +20,6 @@
import logging
import unittest
from builtins import range
-from concurrent import futures
import grpc
@@ -28,6 +27,7 @@
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.worker import log_handler
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
@@ -47,7 +47,7 @@
def setUp(self):
self.test_logging_service = BeamFnLoggingServicer()
- self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ self.server = grpc.server(UnboundedThreadPoolExecutor())
beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
self.test_logging_service, self.server)
self.test_port = self.server.add_insecure_port('[::]:0')
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 74a3e99..2cbd196 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -31,7 +31,6 @@
import traceback
from builtins import object
from builtins import range
-from concurrent import futures
import grpc
from future.utils import raise_
@@ -45,6 +44,7 @@
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
# This SDK harness will (by default), log a "lull" in processing if it sees no
# transitions in over 5 minutes.
@@ -97,15 +97,9 @@
# one worker for progress/split request.
self.progress_worker = SdkWorker(self._bundle_processor_cache,
profiler_factory=self._profiler_factory)
- # one thread is enough for getting the progress report.
- # Assumption:
- # Progress report generation should not do IO or wait on other resources.
- # Without wait, having multiple threads will not improve performance and
- # will only add complexity.
- self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+ self._progress_thread_pool = UnboundedThreadPoolExecutor()
# finalize and process share one thread pool.
- self._process_thread_pool = futures.ThreadPoolExecutor(
- max_workers=self._worker_count)
+ self._process_thread_pool = UnboundedThreadPoolExecutor()
self._responses = queue.Queue()
self._process_bundle_queue = queue.Queue()
self._unscheduled_process_bundle = {}
@@ -202,7 +196,7 @@
self._unscheduled_process_bundle[request.instruction_id] = time.time()
self._process_thread_pool.submit(task)
logging.debug(
- "Currently using %s threads." % len(self._process_thread_pool._threads))
+ "Currently using %s threads." % len(self._process_thread_pool._workers))
def _request_process_bundle_split(self, request):
self._request_process_bundle_action(request)
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index c6cb8ed..ce2bb1f 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -190,7 +190,7 @@
future releases.
Returns:
- an int containing the worker_threads to use. Default is 12
+ an int containing the worker_threads to use. Default is 12.
"""
experiments = pipeline_options.view_as(DebugOptions).experiments
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 71263a8..89047ef 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -23,7 +23,6 @@
import logging
import unittest
from builtins import range
-from concurrent import futures
import grpc
@@ -31,6 +30,7 @@
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.worker import sdk_worker
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
@@ -93,7 +93,7 @@
test_controller = BeamFnControlServicer(requests)
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ server = grpc.server(UnboundedThreadPoolExecutor())
beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
test_controller, server)
test_port = server.add_insecure_port("[::]:0")
diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py
index 11db233..beacd30 100644
--- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py
+++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py
@@ -35,13 +35,13 @@
import sys
import threading
import time
-from concurrent import futures
import grpc
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import sdk_worker
+from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
class BeamFnExternalWorkerPoolServicer(
@@ -60,7 +60,7 @@
@classmethod
def start(cls, worker_threads=1, use_process=False, port=0,
state_cache_size=0, container_executable=None):
- worker_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ worker_server = grpc.server(UnboundedThreadPoolExecutor())
worker_address = 'localhost:%s' % worker_server.add_insecure_port(
'[::]:%s' % port)
worker_pool = cls(worker_threads,
diff --git a/sdks/python/apache_beam/utils/thread_pool_executor.py b/sdks/python/apache_beam/utils/thread_pool_executor.py
index 5e324cc..aba8f5ad 100644
--- a/sdks/python/apache_beam/utils/thread_pool_executor.py
+++ b/sdks/python/apache_beam/utils/thread_pool_executor.py
@@ -29,8 +29,8 @@
class _WorkItem(object):
- def __init__(self, f, fn, args, kwargs):
- self._future = f
+ def __init__(self, future, fn, args, kwargs):
+ self._future = future
self._fn = fn
self._fn_args = args
self._fn_kwargs = kwargs
@@ -40,9 +40,16 @@
# If the future wasn't cancelled, then attempt to execute it.
try:
self._future.set_result(self._fn(*self._fn_args, **self._fn_kwargs))
- except:
- e, tb = sys.exc_info()[1:]
- self._future.set_exception_info(e, tb)
+ except BaseException as exc:
+      # Even though the Python 2 futures library has #set_exception(),
+      # the way it generates the traceback doesn't align with
+      # the way in which Python 3 does it, so we provide alternative
+      # implementations that match our test expectations.
+ if sys.version_info.major >= 3:
+ self._future.set_exception(exc)
+ else:
+ e, tb = sys.exc_info()[1:]
+ self._future.set_exception_info(e, tb)
class _Worker(threading.Thread):
@@ -83,7 +90,8 @@
# around through to the wait() won't block and we will exit
# since _work_item will be unset.
- # We only exit when _work_item is unset to prevent dropping of submitted work.
+ # We only exit when _work_item is unset to prevent dropping of
+ # submitted work.
if self._work_item is None:
self._shutdown = True
return
@@ -125,15 +133,15 @@
A runtime error is raised if the pool has been shutdown.
"""
- f = _base.Future()
- work_item = _WorkItem(f, fn, args, kwargs)
+ future = _base.Future()
+ work_item = _WorkItem(future, fn, args, kwargs)
try:
# Keep trying to get an idle worker from the queue until we find one
# that accepts the work.
while not self._idle_worker_queue.get(
block=False).accepted_work(work_item):
pass
- return f
+ return future
except queue.Empty:
with self._lock:
if self._shutdown:
@@ -146,7 +154,7 @@
worker.daemon = True
worker.start()
self._workers.add(worker)
- return f
+ return future
def shutdown(self, wait=True):
with self._lock:
diff --git a/sdks/python/apache_beam/utils/thread_pool_executor_test.py b/sdks/python/apache_beam/utils/thread_pool_executor_test.py
index cf682be..3616409 100644
--- a/sdks/python/apache_beam/utils/thread_pool_executor_test.py
+++ b/sdks/python/apache_beam/utils/thread_pool_executor_test.py
@@ -21,6 +21,7 @@
import threading
import time
+import traceback
import unittest
# patches unittest.TestCase to be python3 compatible
@@ -39,6 +40,9 @@
self._worker_idents.append(threading.current_thread().ident)
time.sleep(sleep_time)
+ def raise_error(self, message):
+ raise ValueError(message)
+
def test_shutdown_with_no_workers(self):
with UnboundedThreadPoolExecutor():
pass
@@ -83,6 +87,19 @@
self.assertEqual(10, len(self._worker_idents))
self.assertTrue(len(set(self._worker_idents)) < 10)
+ def test_exception_propagation(self):
+ with UnboundedThreadPoolExecutor() as executor:
+ future = executor.submit(self.raise_error, 'footest')
+
+ try:
+ future.result()
+ except Exception:
+ message = traceback.format_exc()
+ else:
+ raise AssertionError('expected exception not raised')
+
+ self.assertIn('footest', message)
+ self.assertIn('raise_error', message)
if __name__ == '__main__':
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index bc77fd8..3660b72 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -187,6 +187,7 @@
# Sphinx cannot find this py:class reference target
'typing.Generic',
+ 'concurrent.futures._base.Executor',
]
# When inferring a base class it will use ':py:class'; if inferring a function