# blob: 498a07b70e9e6d84aa4a68559c094e2812642861 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for apache_beam.runners.worker.sdk_worker_main."""
# pytype: skip-file
import io
import logging
import os
import unittest
from unittest import mock

from hamcrest import all_of
from hamcrest import assert_that
from hamcrest import has_entry

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.runners.worker import sdk_worker_main
from apache_beam.runners.worker import worker_status
from apache_beam.utils.plugin import BeamPlugin
class SdkWorkerMainTest(unittest.TestCase):
  """Tests for the helper functions exposed by ``sdk_worker_main``."""

  # Used for testing newly added flags. This class is never referenced by
  # name; presumably PipelineOptions picks up its subclasses when parsing,
  # which is what makes the mock flags visible below -- confirm.
  class MockOptions(PipelineOptions):
    @classmethod
    def _add_argparse_args(cls, parser):
      # Near-miss spellings of the "beam:option:<name>:v1" form;
      # test_parse_pipeline_options checks that they are kept verbatim.
      parser.add_argument('--eam:option:m_option:v', help='mock option')
      parser.add_argument('--eam:option:m_option:v1', help='mock option')
      parser.add_argument('--beam:option:m_option:v', help='mock option')
      parser.add_argument('--m_flag', action='store_true', help='mock flag')
      parser.add_argument('--m_option', help='mock option')
      parser.add_argument(
          '--m_m_option', action='append', help='mock multi option')

  def _reset_global_state_on_cleanup(self):
    """Register cleanups that undo global state create_harness() may touch.

    test_runtime_values shows that create_harness() publishes the pipeline
    options to RuntimeValueProvider process-wide. It may also change the root
    logger's level or attach handlers to it. Without these cleanups that state
    would leak into later tests.
    """
    self.addCleanup(RuntimeValueProvider.set_runtime_options, None)

    root_logger = logging.getLogger()
    saved_level = root_logger.level
    saved_handlers = list(root_logger.handlers)

    def restore_root_logger():
      root_logger.setLevel(saved_level)
      # Iterate over a copy: removeHandler mutates root_logger.handlers.
      for handler in list(root_logger.handlers):
        if handler not in saved_handlers:
          root_logger.removeHandler(handler)

    self.addCleanup(restore_root_logger)

  def test_status_server(self):
    # Run the check inside a uniquely named function so we can verify that
    # the current stack frame shows up in the thread dump.
    def wrapped_method_for_test():
      threaddump = worker_status.thread_dump()
      self.assertRegex(threaddump, '.*wrapped_method_for_test.*')

    wrapped_method_for_test()

  def test_parse_pipeline_options(self):
    # Options nested under an "options" key are returned by their plain names.
    assert_that(
        sdk_worker_main._parse_pipeline_options(
            '{"options": {' + '"m_option": "/tmp/requirements.txt", ' +
            '"m_m_option":["beam_fn_api"]' + '}}').get_all_options(),
        all_of(
            has_entry('m_m_option', ['beam_fn_api']),
            has_entry('m_option', '/tmp/requirements.txt')))
    # Top-level "beam:option:<name>:v1" keys are mapped to <name>.
    assert_that(
        sdk_worker_main._parse_pipeline_options(
            '{"beam:option:m_option:v1": "/tmp/requirements.txt", ' +
            '"beam:option:m_m_option:v1":["beam_fn_api"]}').get_all_options(),
        all_of(
            has_entry('m_m_option', ['beam_fn_api']),
            has_entry('m_option', '/tmp/requirements.txt')))
    # Keys that only resemble that form are kept verbatim.
    assert_that(
        sdk_worker_main._parse_pipeline_options(
            '{"options": {"beam:option:m_option:v":"mock_val"}}').
        get_all_options(),
        has_entry('beam:option:m_option:v', 'mock_val'))
    assert_that(
        sdk_worker_main._parse_pipeline_options(
            '{"options": {"eam:option:m_option:v1":"mock_val"}}').
        get_all_options(),
        has_entry('eam:option:m_option:v1', 'mock_val'))
    assert_that(
        sdk_worker_main._parse_pipeline_options(
            '{"options": {"eam:option:m_option:v":"mock_val"}}').
        get_all_options(),
        has_entry('eam:option:m_option:v', 'mock_val'))

  def test_runtime_values(self):
    self._reset_global_state_on_cleanup()
    # Created before create_harness(); it only becomes accessible once the
    # harness publishes the runtime options.
    test_runtime_provider = RuntimeValueProvider('test_param', int, None)
    sdk_worker_main.create_harness(
        {
            'CONTROL_API_SERVICE_DESCRIPTOR': '',
            'PIPELINE_OPTIONS': '{"test_param": 37}',
        },
        dry_run=True)
    self.assertTrue(test_runtime_provider.is_accessible())
    self.assertEqual(test_runtime_provider.get(), 37)

  def test_create_sdk_harness_log_handler_received_log(self):
    # Tests that the log handler created in create_harness() does not miss
    # logs emitted from create_harness() itself.
    self._reset_global_state_on_cleanup()
    logstream = io.StringIO()

    class InMemoryHandler(logging.StreamHandler):
      # Stand-in for FnApiLogRecordHandler that writes to logstream. Any
      # constructor arguments meant for the real handler are ignored.
      def __init__(self, *unused):
        super().__init__(stream=logstream)

    with mock.patch(
        'apache_beam.runners.worker.sdk_worker_main.FnApiLogRecordHandler',
        InMemoryHandler):
      sdk_worker_main.create_harness(
          {
              'LOGGING_API_SERVICE_DESCRIPTOR': '',
              'CONTROL_API_SERVICE_DESCRIPTOR': '',
              'PIPELINE_OPTIONS': '{"default_sdk_harness_log_level":"INVALID",'
              '"sdk_harness_log_level_overrides":[]}',
          },
          dry_run=True)
    logs = logstream.getvalue()
    self.assertIn('Unknown log level', logs)
    self.assertIn('Unable to parse sdk_harness_log_level_overrides', logs)

  def test_import_beam_plugins(self):
    # Smoke test: importing every registered plugin path must not raise.
    sdk_worker_main._import_beam_plugins(BeamPlugin.get_all_plugin_paths())

  @staticmethod
  def _overrides_case_to_option_dict(case):
    """
    Return logging level overrides from command line strings via PipelineOption.

    Each string in ``case`` is passed as a separate
    ``--sdk_harness_log_level_overrides`` flag.
    """
    options_list = []
    for c in case:
      options_list += ['--sdk_harness_log_level_overrides', c]
    options = PipelineOptions(options_list)
    return options.get_all_options()

  def test__get_log_level_from_options_dict(self):
    # (options dict, expected level) pairs. Pairing them avoids the silent
    # truncation zip() would apply to two lists of different lengths.
    test_cases = [
        ({}, logging.INFO),  # not provided: default level
        ({
            'default_sdk_harness_log_level': 'DEBUG'
        }, logging.DEBUG),
        ({
            'default_sdk_harness_log_level': '30'
        }, 30),  # numeric strings are accepted
        ({
            'default_sdk_harness_log_level': 'INVALID_ENTRY'
        }, logging.INFO),  # unknown names fall back to the default
    ]
    for case, expected in test_cases:
      with self.subTest(case=case):
        self.assertEqual(
            sdk_worker_main._get_log_level_from_options_dict(case), expected)

  def test__set_log_level_overrides(self):
    test_cases = [
        ([], {}),  # not provided, as a smoke test
        (
            # single overrides
            ['{"fake_module_1a.b":"DEBUG","fake_module_1c.d":"INFO"}'],
            {
                "fake_module_1a.b": logging.DEBUG,
                "fake_module_1a.b.f": logging.DEBUG,
                "fake_module_1c.d": logging.INFO
            }),
        (
            # multiple overrides, the last takes precedence
            [
                '{"fake_module_2a.b":"DEBUG"}',
                '{"fake_module_2c.d":"WARNING","fake_module_2c.d.e":15}',
                '{"fake_module_2c.d":"ERROR"}'
            ],
            {
                "fake_module_2a.b": logging.DEBUG,
                "fake_module_2a.b.f": logging.DEBUG,
                "fake_module_2c.d": logging.ERROR,
                "fake_module_2c.d.e": 15,
                "fake_module_2c.d.f": logging.ERROR
            })
    ]
    for case, expected in test_cases:
      with self.subTest(case=case):
        overrides = self._overrides_case_to_option_dict(case)
        sdk_worker_main._set_log_level_overrides(overrides)
        # Child loggers (e.g. "*.f") are checked too: they inherit the
        # overridden level as their effective level.
        for name, level in expected.items():
          self.assertEqual(logging.getLogger(name).getEffectiveLevel(), level)

  def test__set_log_level_overrides_error(self):
    test_cases = [
        (['{"json.value.is.not.level": ["ERROR"]}'],
         "Error occurred when setting log level"),
        (['{"invalid.level":"INVALID"}'],
         "Error occurred when setting log level"),
    ]
    for case, expected in test_cases:
      with self.subTest(case=case):
        overrides = self._overrides_case_to_option_dict(case)
        with self.assertLogs('apache_beam.runners.worker.sdk_worker_main',
                             level='ERROR') as cm:
          sdk_worker_main._set_log_level_overrides(overrides)
        self.assertIn(expected, cm.output[0])

  # The profiler tests check the service name that would be passed to
  # _start_profiler. They do not replace the module-level function, so no
  # profiler is started and no mock is left behind for other tests.
  def test_gcp_profiler_uses_provided_service_name_when_specified(self):
    options = PipelineOptions(
        ['--dataflow_service_options=enable_google_cloud_profiler=sample'])
    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
        options)
    self.assertEqual(gcp_profiler_name, "sample")

  @mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
  def test_gcp_profiler_uses_job_name_when_service_name_not_specified(self):
    options = PipelineOptions(
        ['--dataflow_service_options=enable_google_cloud_profiler'])
    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
        options)
    self.assertEqual(gcp_profiler_name, "sample_job")

  @mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
  def test_gcp_profiler_uses_job_name_when_enabled_as_experiment(self):
    options = PipelineOptions(['--experiment=enable_google_cloud_profiler'])
    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
        options)
    self.assertEqual(gcp_profiler_name, "sample_job")

  @mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
  def test_pipeline_option_max_cache_memory_usage_mb(self):
    options = PipelineOptions(flags=['--max_cache_memory_usage_mb=50'])
    cache_size = sdk_worker_main._get_state_cache_size_bytes(options)
    self.assertEqual(cache_size, 50 << 20)  # 50 MiB in bytes

  @mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
  def test_pipeline_option_max_cache_memory_usage_mb_with_experiments(self):
    options = PipelineOptions(flags=['--experiments=state_cache_size=50'])
    cache_size = sdk_worker_main._get_state_cache_size_bytes(options)
    self.assertEqual(cache_size, 50 << 20)  # 50 MiB in bytes
if __name__ == '__main__':
  # When run directly, show INFO-and-above records from the root logger
  # (logging.root is the logger returned by logging.getLogger()).
  logging.root.setLevel(logging.INFO)
  unittest.main()