[BEAM-8151] Create SdkWorkers on demand when processing bundles
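
Previously SdkHarness pre-created a fixed pool of SdkWorkers (sized by
the experimental 'worker_threads' flag, default 12), queued incoming
process_bundle requests against it, and ran a monitor thread that
warned whenever a bundle stayed unscheduled for more than five
minutes. Since all requests already execute on an
UnboundedThreadPoolExecutor, a fixed worker pool only adds scheduling
latency and configuration surface.

This change:

* creates an SdkWorker on demand whenever the idle-worker queue is
  empty, returning it to the queue once the instruction completes;
* merges the dedicated progress worker and the separate progress and
  process thread pools into a single UnboundedThreadPoolExecutor;
* deletes the scheduling-delay monitor and the unscheduled-bundle
  bookkeeping, since bundles no longer wait for a free worker;
* removes the 'worker_threads' experiment, the --threads_per_worker
  flag of the external worker pool, and the num_workers field of
  EmbeddedPythonGrpcEnvironment. The EMBEDDED_PYTHON_GRPC payload and
  --environment_config now carry only the state cache size (e.g.
  '100') instead of the former 'num_workers,state_cache_size' pair.

Every request handler now uses the same dispatch pattern; in sketch
form (the real implementation is in sdk_worker.py below):

    worker = self._get_or_create_worker()
    try:
      self._execute(lambda: worker.do_instruction(request), request)
    finally:
      self.workers.put(worker)  # return the worker to the idle pool
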
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index b56c26a..e479708 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -1383,17 +1383,15 @@
super(EmbeddedGrpcWorkerHandler, self).__init__(state, provision_info,
grpc_server)
if payload:
- num_workers, state_cache_size = payload.decode('ascii').split(',')
- self._num_threads = int(num_workers)
+ state_cache_size = payload.decode('ascii')
self._state_cache_size = int(state_cache_size)
else:
- self._num_threads = 1
self._state_cache_size = STATE_CACHE_SIZE
def start_worker(self):
self.worker = sdk_worker.SdkHarness(
- self.control_address, worker_count=self._num_threads,
- state_cache_size=self._state_cache_size, worker_id=self.worker_id)
+ self.control_address, state_cache_size=self._state_cache_size,
+ worker_id=self.worker_id)
self.worker_thread = threading.Thread(
name='run_worker', target=self.worker.run)
self.worker_thread.daemon = True
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 846fe58..e31e0a5 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -1138,23 +1138,13 @@
default_environment=environments.EmbeddedPythonGrpcEnvironment()))
-class FnApiRunnerTestWithGrpcMultiThreaded(FnApiRunnerTest):
-
- def create_pipeline(self):
- return beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(
- default_environment=environments.EmbeddedPythonGrpcEnvironment(
- num_workers=2,
- state_cache_size=fn_api_runner.STATE_CACHE_SIZE)))
-
-
class FnApiRunnerTestWithDisabledCaching(FnApiRunnerTest):
def create_pipeline(self):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
default_environment=environments.EmbeddedPythonGrpcEnvironment(
- num_workers=2, state_cache_size=0)))
+ state_cache_size=0)))
class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 8b078a5..c3e4176 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -139,7 +139,6 @@
'use_loopback_process_worker', False)
portable_options.environment_config, server = (
worker_pool_main.BeamFnExternalWorkerPoolServicer.start(
- sdk_worker_main._get_worker_count(options),
state_cache_size=sdk_worker_main._get_state_cache_size(options),
use_process=use_loopback_process_worker))
cleanup_callbacks = [functools.partial(server.stop, 1)]
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 2cbd196..488f505 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -27,10 +27,8 @@
import queue
import sys
import threading
-import time
import traceback
from builtins import object
-from builtins import range
import grpc
from future.utils import raise_
@@ -54,17 +52,15 @@
class SdkHarness(object):
REQUEST_METHOD_PREFIX = '_request_'
- SCHEDULING_DELAY_THRESHOLD_SEC = 5*60 # 5 Minutes
def __init__(
- self, control_address, worker_count,
+ self, control_address,
credentials=None,
worker_id=None,
# Caching is disabled by default
state_cache_size=0,
profiler_factory=None):
self._alive = True
- self._worker_count = worker_count
self._worker_index = 0
self._worker_id = worker_id
self._state_cache = StateCache(state_cache_size)
@@ -94,37 +90,16 @@
fns=self._fns)
- # workers for process/finalize bundle.
+ # Pool of idle SdkWorkers, grown on demand and reused across requests.
self.workers = queue.Queue()
- # one worker for progress/split request.
- self.progress_worker = SdkWorker(self._bundle_processor_cache,
- profiler_factory=self._profiler_factory)
- self._progress_thread_pool = UnboundedThreadPoolExecutor()
- # finalize and process share one thread pool.
- self._process_thread_pool = UnboundedThreadPoolExecutor()
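+ # All instruction requests run on a single unbounded thread pool.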
+ self._worker_thread_pool = UnboundedThreadPoolExecutor()
self._responses = queue.Queue()
- self._process_bundle_queue = queue.Queue()
- self._unscheduled_process_bundle = {}
- logging.info('Initializing SDKHarness with %s workers.', self._worker_count)
+ logging.info(
+ 'Initializing SDKHarness with an unbounded number of workers.')
def run(self):
control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
no_more_work = object()
- # Create process workers
- for _ in range(self._worker_count):
- # SdkHarness manage function registration and share self._fns with all
- # the workers. This is needed because function registration (register)
- # and execution (process_bundle) are send over different request and we
- # do not really know which worker is going to process bundle
- # for a function till we get process_bundle request. Moreover
- # same function is reused by different process bundle calls and
- # potentially get executed by different worker. Hence we need a
- # centralized function list shared among all the workers.
- self.workers.put(
- SdkWorker(self._bundle_processor_cache,
- state_cache_metrics_fn=
- self._state_cache.get_monitoring_infos,
- profiler_factory=self._profiler_factory))
-
def get_responses():
while True:
response = self._responses.get()
@@ -133,10 +106,6 @@
yield response
self._alive = True
- monitoring_thread = threading.Thread(name='SdkHarness_monitor',
- target=self._monitor_process_bundle)
- monitoring_thread.daemon = True
- monitoring_thread.start()
try:
for work_request in control_stub.Control(get_responses()):
@@ -152,8 +121,7 @@
logging.info('No more requests from control plane')
logging.info('SDK Harness waiting for in-flight requests to complete')
# Wait until existing requests are processed.
- self._progress_thread_pool.shutdown()
- self._process_thread_pool.shutdown()
+ self._worker_thread_pool.shutdown()
# get_responses may be blocked on responses.get(), but we need to return
# control to its caller.
self._responses.put(no_more_work)
@@ -181,22 +149,15 @@
def _request_process_bundle(self, request):
def task():
- # Take the free worker. Wait till a worker is free.
- worker = self.workers.get()
- # Get the first work item in the queue
- work = self._process_bundle_queue.get()
- self._unscheduled_process_bundle.pop(work.instruction_id, None)
+ worker = self._get_or_create_worker()
try:
- self._execute(lambda: worker.do_instruction(work), work)
+ self._execute(lambda: worker.do_instruction(request), request)
finally:
# Put the worker back in the free worker pool
self.workers.put(worker)
- # Create a task for each process_bundle request and schedule it
- self._process_bundle_queue.put(request)
- self._unscheduled_process_bundle[request.instruction_id] = time.time()
- self._process_thread_pool.submit(task)
+ self._worker_thread_pool.submit(task)
logging.debug(
- "Currently using %s threads." % len(self._process_thread_pool._workers))
+ 'Currently using %s threads.', len(self._worker_thread_pool._workers))
def _request_process_bundle_split(self, request):
self._request_process_bundle_action(request)
@@ -212,17 +173,19 @@
- # only process progress/split request when a bundle is in processing.
+ # Only process progress/split requests while a bundle is being processed.
if (instruction_id in
self._bundle_processor_cache.active_bundle_processors):
- self._execute(
- lambda: self.progress_worker.do_instruction(request), request)
+ worker = self._get_or_create_worker()
+ try:
+ self._execute(lambda: worker.do_instruction(request), request)
+ finally:
+ # Put the worker back in the free worker pool
+ self.workers.put(worker)
else:
self._execute(lambda: beam_fn_api_pb2.InstructionResponse(
instruction_id=request.instruction_id, error=(
- 'Process bundle request not yet scheduled for instruction {}' if
- instruction_id in self._unscheduled_process_bundle else
'Unknown process bundle instruction {}').format(
instruction_id)), request)
- self._progress_thread_pool.submit(task)
+ self._worker_thread_pool.submit(task)
def _request_finalize_bundle(self, request):
self._request_execute(request)
@@ -231,37 +194,24 @@
def task():
- # Get one available worker.
+ # Reuse an idle worker, or create one if none is free.
- worker = self.workers.get()
+ worker = self._get_or_create_worker()
try:
- self._execute(
- lambda: worker.do_instruction(request), request)
+ self._execute(lambda: worker.do_instruction(request), request)
finally:
# Put the worker back in the free worker pool.
self.workers.put(worker)
- self._process_thread_pool.submit(task)
+ self._worker_thread_pool.submit(task)
- def _monitor_process_bundle(self):
- """
- Monitor the unscheduled bundles and log if a bundle is not scheduled for
- more than SCHEDULING_DELAY_THRESHOLD_SEC.
- """
- while self._alive:
- time.sleep(SdkHarness.SCHEDULING_DELAY_THRESHOLD_SEC)
- # Check for bundles to be scheduled.
- if self._unscheduled_process_bundle:
- current_time = time.time()
- for instruction_id in self._unscheduled_process_bundle:
- request_time = None
- try:
- request_time = self._unscheduled_process_bundle[instruction_id]
- except KeyError:
- pass
- if request_time:
- scheduling_delay = current_time - request_time
- if scheduling_delay > SdkHarness.SCHEDULING_DELAY_THRESHOLD_SEC:
- logging.warning('Unable to schedule instruction %s for %s',
- instruction_id, scheduling_delay)
+ def _get_or_create_worker(self):
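+ """Returns an idle worker, creating one if none is free."""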
+ try:
+ return self.workers.get_nowait()
+ except queue.Empty:
+ return SdkWorker(self._bundle_processor_cache,
+ state_cache_metrics_fn=
+ self._state_cache.get_monitoring_infos,
+ profiler_factory=self._profiler_factory)
class BundleProcessorCache(object):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index ce2bb1f..2467965 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -147,7 +147,6 @@
assert not service_descriptor.oauth2_client_credentials_grant.url
SdkHarness(
control_address=service_descriptor.url,
- worker_count=_get_worker_count(sdk_pipeline_options),
worker_id=_worker_id,
state_cache_size=_get_state_cache_size(sdk_pipeline_options),
profiler_factory=profiler.Profile.factory_from_options(
@@ -177,35 +176,6 @@
})
-def _get_worker_count(pipeline_options):
- """Extract worker count from the pipeline_options.
-
- This defines how many SdkWorkers will be started in this Python process.
- And each SdkWorker will have its own thread to process data. Name of the
- experimental parameter is 'worker_threads'
- Example Usage in the Command Line:
- --experimental worker_threads=1
-
- Note: worker_threads is an experimental flag and might not be available in
- future releases.
-
- Returns:
- an int containing the worker_threads to use. Default is 12.
- """
- experiments = pipeline_options.view_as(DebugOptions).experiments
-
- experiments = experiments if experiments else []
-
- for experiment in experiments:
- # There should only be 1 match so returning from the loop
- if re.match(r'worker_threads=', experiment):
- return int(
- re.match(r'worker_threads=(?P<worker_threads>.*)',
- experiment).group('worker_threads'))
-
- return 12
-
-
def _get_state_cache_size(pipeline_options):
"""Defines the upper number of state items to cache.
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
index 9703515..cae65a2 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
@@ -20,7 +20,6 @@
from __future__ import division
from __future__ import print_function
-import json
import logging
import unittest
@@ -56,40 +55,24 @@
wrapped_method_for_test()
- def test_work_count_default_value(self):
- self._check_worker_count('{}', 12)
-
def test_parse_pipeline_options(self):
expected_options = PipelineOptions([])
expected_options.view_as(
- SdkWorkerMainTest.MockOptions).m_m_option = [
- 'worker_threads=1', 'beam_fn_api'
- ]
+ SdkWorkerMainTest.MockOptions).m_m_option = ['beam_fn_api']
expected_options.view_as(
SdkWorkerMainTest.MockOptions).m_option = '/tmp/requirements.txt'
self.assertEqual(
- {'m_m_option': ['worker_threads=1']},
- sdk_worker_main._parse_pipeline_options(
- '{"options": {"m_m_option":["worker_threads=1"]}}')
- .get_all_options(drop_default=True))
- self.assertEqual(
expected_options.get_all_options(),
sdk_worker_main._parse_pipeline_options(
'{"options": {' +
'"m_option": "/tmp/requirements.txt", ' +
- '"m_m_option":["worker_threads=1", "beam_fn_api"]' +
+ '"m_m_option":["beam_fn_api"]' +
'}}').get_all_options())
self.assertEqual(
- {'m_m_option': ['worker_threads=1']},
- sdk_worker_main._parse_pipeline_options(
- '{"beam:option:m_m_option:v1":["worker_threads=1"]}')
- .get_all_options(drop_default=True))
- self.assertEqual(
expected_options.get_all_options(),
sdk_worker_main._parse_pipeline_options(
'{"beam:option:m_option:v1": "/tmp/requirements.txt", ' +
- '"beam:option:m_m_option:v1":["worker_threads=1", ' +
- '"beam_fn_api"]}').get_all_options())
+ '"beam:option:m_m_option:v1":["beam_fn_api"]}').get_all_options())
self.assertEqual(
{'beam:option:m_option:v': 'mock_val'},
sdk_worker_main._parse_pipeline_options(
@@ -106,30 +89,6 @@
'{"options": {"eam:option:m_option:v":"mock_val"}}')
.get_all_options(drop_default=True))
- def test_work_count_custom_value(self):
- self._check_worker_count('{"experiments":["worker_threads=1"]}', 1)
- self._check_worker_count('{"experiments":["worker_threads=4"]}', 4)
- self._check_worker_count('{"experiments":["worker_threads=12"]}', 12)
-
- def test_work_count_wrong_format(self):
- self._check_worker_count(
- '{"experiments":["worker_threads="]}', exception=True)
- self._check_worker_count(
- '{"experiments":["worker_threads=a"]}', exception=True)
- self._check_worker_count(
- '{"experiments":["worker_threads=1a"]}', exception=True)
-
- def _check_worker_count(self, pipeline_options, expected=0, exception=False):
- if exception:
- self.assertRaises(
- Exception, sdk_worker_main._get_worker_count,
- PipelineOptions.from_dictionary(json.loads(pipeline_options)))
- else:
- self.assertEqual(
- sdk_worker_main._get_worker_count(
- PipelineOptions.from_dictionary(json.loads(pipeline_options))),
- expected)
-
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 89047ef..a422851 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -78,7 +78,6 @@
- tuple of request_count, number of process_bundles per request and workers
- counts to process the request.
+ tuple of request_count and number of process_bundles per request.
"""
- for (request_count, process_bundles_per_request, worker_count) in args:
+ for (request_count, process_bundles_per_request) in args:
requests = []
process_bundle_descriptors = []
@@ -100,8 +100,7 @@
server.start()
harness = sdk_worker.SdkHarness(
- "localhost:%s" % test_port, worker_count=worker_count,
- state_cache_size=100)
+ "localhost:%s" % test_port, state_cache_size=100)
harness.run()
for worker in harness.workers.queue:
@@ -110,7 +109,7 @@
for item in process_bundle_descriptors})
def test_fn_registration(self):
- self._check_fn_registration_multi_request((1, 4, 1), (4, 4, 1), (4, 4, 2))
+ self._check_fn_registration_multi_request((1, 4), (4, 4))
if __name__ == "__main__":
diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py
index beacd30..6b1437c 100644
--- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py
+++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py
@@ -47,24 +47,22 @@
class BeamFnExternalWorkerPoolServicer(
beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer):
- def __init__(self, worker_threads,
+ def __init__(self,
use_process=False,
container_executable=None,
state_cache_size=0):
- self._worker_threads = worker_threads
self._use_process = use_process
self._container_executable = container_executable
self._state_cache_size = state_cache_size
self._worker_processes = {}
@classmethod
- def start(cls, worker_threads=1, use_process=False, port=0,
+ def start(cls, use_process=False, port=0,
state_cache_size=0, container_executable=None):
worker_server = grpc.server(UnboundedThreadPoolExecutor())
worker_address = 'localhost:%s' % worker_server.add_insecure_port(
'[::]:%s' % port)
- worker_pool = cls(worker_threads,
- use_process=use_process,
+ worker_pool = cls(use_process=use_process,
container_executable=container_executable,
state_cache_size=state_cache_size)
beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
@@ -88,13 +86,11 @@
'import SdkHarness; '
'SdkHarness('
'"%s",'
- 'worker_count=%d,'
'worker_id="%s",'
'state_cache_size=%d'
')'
'.run()' % (
start_worker_request.control_endpoint.url,
- self._worker_threads,
start_worker_request.worker_id,
self._state_cache_size)]
if self._container_executable:
@@ -120,7 +116,6 @@
else:
worker = sdk_worker.SdkHarness(
start_worker_request.control_endpoint.url,
- worker_count=self._worker_threads,
worker_id=start_worker_request.worker_id,
state_cache_size=self._state_cache_size)
worker_thread = threading.Thread(
@@ -157,11 +152,6 @@
"""Entry point for worker pool service for external environments."""
parser = argparse.ArgumentParser()
- parser.add_argument('--threads_per_worker',
- type=int,
- default=argparse.SUPPRESS,
- dest='worker_threads',
- help='Number of threads per SDK worker.')
parser.add_argument('--container_executable',
type=str,
default=None,
diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py
index 999647f..6f67266 100644
--- a/sdks/python/apache_beam/transforms/environments.py
+++ b/sdks/python/apache_beam/transforms/environments.py
@@ -308,13 +308,11 @@
@Environment.register_urn(python_urns.EMBEDDED_PYTHON_GRPC, bytes)
class EmbeddedPythonGrpcEnvironment(Environment):
- def __init__(self, num_workers=None, state_cache_size=None):
- self.num_workers = num_workers
+ def __init__(self, state_cache_size=None):
self.state_cache_size = state_cache_size
def __eq__(self, other):
return self.__class__ == other.__class__ \
- and self.num_workers == other.num_workers \
and self.state_cache_size == other.state_cache_size
def __ne__(self, other):
@@ -322,34 +320,26 @@
return not self == other
def __hash__(self):
- return hash((self.__class__, self.num_workers, self.state_cache_size))
+ return hash((self.__class__, self.state_cache_size))
def __repr__(self):
repr_parts = []
- if not self.num_workers is None:
- repr_parts.append('num_workers=%d' % self.num_workers)
if not self.state_cache_size is None:
repr_parts.append('state_cache_size=%d' % self.state_cache_size)
return 'EmbeddedPythonGrpcEnvironment(%s)' % ','.join(repr_parts)
def to_runner_api_parameter(self, context):
- if self.num_workers is None and self.state_cache_size is None:
+ if self.state_cache_size is None:
payload = b''
- elif self.num_workers is not None and self.state_cache_size is not None:
- payload = b'%d,%d' % (self.num_workers, self.state_cache_size)
else:
- # We want to make sure that the environment stays the same through the
- # roundtrip to runner api, so here we don't want to set default for the
- # other if only one of num workers or state cache size is set
- raise ValueError('Must provide worker num and state cache size.')
+ payload = b'%d' % self.state_cache_size
return python_urns.EMBEDDED_PYTHON_GRPC, payload
@staticmethod
def from_runner_api_parameter(payload, context):
if payload:
- num_workers, state_cache_size = payload.decode('utf-8').split(',')
+ state_cache_size = payload.decode('utf-8')
return EmbeddedPythonGrpcEnvironment(
- num_workers=int(num_workers),
state_cache_size=int(state_cache_size))
else:
return EmbeddedPythonGrpcEnvironment()
@@ -357,8 +347,8 @@
@classmethod
def from_options(cls, options):
if options.environment_config:
- num_workers, state_cache_size = options.environment_config.split(',')
- return cls(num_workers=num_workers, state_cache_size=state_cache_size)
+ # environment_config carries a string; cast so the '%d' payload
+ # serialization in to_runner_api_parameter receives an int.
+ return cls(state_cache_size=int(options.environment_config))
else:
return cls()
diff --git a/sdks/python/apache_beam/transforms/environments_test.py b/sdks/python/apache_beam/transforms/environments_test.py
index 0fd568c..46868e8 100644
--- a/sdks/python/apache_beam/transforms/environments_test.py
+++ b/sdks/python/apache_beam/transforms/environments_test.py
@@ -46,7 +46,7 @@
ExternalEnvironment('localhost:8080', params={'k1': 'v1'}),
EmbeddedPythonEnvironment(),
EmbeddedPythonGrpcEnvironment(),
- EmbeddedPythonGrpcEnvironment(num_workers=2, state_cache_size=0),
+ EmbeddedPythonGrpcEnvironment(state_cache_size=0),
SubprocessSDKEnvironment(command_string=u'foö')):
context = pipeline_context.PipelineContext()
self.assertEqual(
@@ -55,13 +55,6 @@
environment.to_runner_api(context), context)
)
- with self.assertRaises(ValueError) as ctx:
- EmbeddedPythonGrpcEnvironment(num_workers=2).to_runner_api(
- pipeline_context.PipelineContext()
- )
- self.assertIn('Must provide worker num and state cache size.',
- ctx.exception.args)
-
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/utils/thread_pool_executor_test.py b/sdks/python/apache_beam/utils/thread_pool_executor_test.py
index 3616409..c82d0f9 100644
--- a/sdks/python/apache_beam/utils/thread_pool_executor_test.py
+++ b/sdks/python/apache_beam/utils/thread_pool_executor_test.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
+import itertools
import threading
import time
import traceback
@@ -101,6 +102,14 @@
self.assertIn('footest', message)
self.assertIn('raise_error', message)
+ def test_map(self):
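+ # All five mapped tasks should have run by the time the pool shuts down.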
+ with UnboundedThreadPoolExecutor() as executor:
+ executor.map(self.append_and_sleep, itertools.repeat(0.01, 5))
+
+ with self._lock:
+ self.assertEqual(5, len(self._worker_idents))
+
if __name__ == '__main__':
unittest.main()