#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""SDK Fn Harness entry point."""
# pytype: skip-file
import importlib
import json
import logging
import os
import re
import sys
import traceback

from google.protobuf import text_format  # type: ignore # not in typeshed

from apache_beam.internal import pickler
from apache_beam.io import filesystems
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import ProfilingOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.internal import names
from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler
from apache_beam.runners.worker.sdk_worker import SdkHarness
from apache_beam.utils import profiler

_LOGGER = logging.getLogger(__name__)

_ENABLE_GOOGLE_CLOUD_PROFILER = 'enable_google_cloud_profiler'


def _import_beam_plugins(plugins):
for plugin in plugins:
try:
importlib.import_module(plugin)
_LOGGER.debug('Imported beam-plugin %s', plugin)
except ImportError:
try:
_LOGGER.debug((
"Looks like %s is not a module. "
"Trying to import it assuming it's a class"),
plugin)
module, _ = plugin.rsplit('.', 1)
importlib.import_module(module)
_LOGGER.debug('Imported %s for beam-plugin %s', module, plugin)
except ImportError as exc:
_LOGGER.warning('Failed to import beam-plugin %s', plugin, exc_info=exc)


def create_harness(environment, dry_run=False):
"""Creates SDK Fn Harness."""
deferred_exception = None
if 'LOGGING_API_SERVICE_DESCRIPTOR' in environment:
try:
logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
text_format.Merge(
environment['LOGGING_API_SERVICE_DESCRIPTOR'],
logging_service_descriptor)
# Send all logs to the runner.
fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
logging.getLogger().addHandler(fn_log_handler)
_LOGGER.info('Logging handler created.')
except Exception:
_LOGGER.error(
"Failed to set up logging handler, continuing without.",
exc_info=True)
fn_log_handler = None
else:
fn_log_handler = None
pipeline_options_dict = _load_pipeline_options(
environment.get('PIPELINE_OPTIONS'))
default_log_level = _get_log_level_from_options_dict(pipeline_options_dict)
logging.getLogger().setLevel(default_log_level)
_set_log_level_overrides(pipeline_options_dict)
  # These are used for Dataflow templates.
RuntimeValueProvider.set_runtime_options(pipeline_options_dict)
sdk_pipeline_options = PipelineOptions.from_dictionary(pipeline_options_dict)
filesystems.FileSystems.set_options(sdk_pipeline_options)
pickle_library = sdk_pipeline_options.view_as(SetupOptions).pickle_library
pickler.set_library(pickle_library)
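  # Select the pickling library (e.g. dill or cloudpickle) before any pickled
  # user code is deserialized.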
semi_persistent_directory = environment.get('SEMI_PERSISTENT_DIRECTORY', None)
runner_capabilities = frozenset(
environment.get('RUNNER_CAPABILITIES', '').split())
_LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
_worker_id = environment.get('WORKER_ID', None)
if pickle_library != pickler.USE_CLOUDPICKLE:
try:
_load_main_session(semi_persistent_directory)
except LoadMainSessionException:
exception_details = traceback.format_exc()
_LOGGER.error(
'Could not load main session: %s', exception_details, exc_info=True)
raise
except Exception: # pylint: disable=broad-except
summary = (
"Could not load main session. Inspect which external dependencies "
"are used in the main module of your pipeline. Verify that "
"corresponding packages are installed in the pipeline runtime "
"environment and their installed versions match the versions used in "
"pipeline submission environment. For more information, see: https://"
"beam.apache.org/documentation/sdks/python-pipeline-dependencies/")
_LOGGER.error(summary, exc_info=True)
exception_details = traceback.format_exc()
deferred_exception = LoadMainSessionException(
f"{summary} {exception_details}")
_LOGGER.info(
'Pipeline_options: %s',
sdk_pipeline_options.get_all_options(drop_default=True))
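  # Parse the control (and, if provided, status) API endpoints exposed to
  # this worker by the runner.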
control_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
status_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
text_format.Merge(
environment['CONTROL_API_SERVICE_DESCRIPTOR'], control_service_descriptor)
if 'STATUS_API_SERVICE_DESCRIPTOR' in environment:
text_format.Merge(
environment['STATUS_API_SERVICE_DESCRIPTOR'], status_service_descriptor)
# TODO(robertwb): Support authentication.
assert not control_service_descriptor.HasField('authentication')
experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []
enable_heap_dump = 'enable_heap_dump' in experiments
beam_plugins = sdk_pipeline_options.view_as(SetupOptions).beam_plugins or []
_import_beam_plugins(beam_plugins)
if dry_run:
return
data_sampler = DataSampler.create(sdk_pipeline_options)
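  # DataSampler.create returns a sampler only when data sampling was enabled
  # via pipeline options; otherwise it returns None.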
sdk_harness = SdkHarness(
control_address=control_service_descriptor.url,
status_address=status_service_descriptor.url,
worker_id=_worker_id,
state_cache_size=_get_state_cache_size_bytes(
options=sdk_pipeline_options),
data_buffer_time_limit_ms=_get_data_buffer_time_limit_ms(experiments),
profiler_factory=profiler.Profile.factory_from_options(
sdk_pipeline_options.view_as(ProfilingOptions)),
enable_heap_dump=enable_heap_dump,
data_sampler=data_sampler,
deferred_exception=deferred_exception,
runner_capabilities=runner_capabilities)
return fn_log_handler, sdk_harness, sdk_pipeline_options


def _start_profiler(gcp_profiler_service_name, gcp_profiler_service_version):
try:
import googlecloudprofiler
if gcp_profiler_service_name and gcp_profiler_service_version:
googlecloudprofiler.start(
service=gcp_profiler_service_name,
service_version=gcp_profiler_service_version,
verbose=1)
_LOGGER.info('Turning on Google Cloud Profiler.')
else:
raise RuntimeError('Unable to find the job id or job name from envvar.')
  except Exception as e:  # pylint: disable=broad-except
    _LOGGER.warning(
        'Unable to start Google Cloud Profiler due to error: %s. For how to '
        'enable Cloud Profiler with Dataflow see '
        'https://cloud.google.com/dataflow/docs/guides/profiling-a-pipeline. '
        'For troubleshooting tips with Cloud Profiler see '
        'https://cloud.google.com/profiler/docs/troubleshooting.',
        e)


def _get_gcp_profiler_name_if_enabled(sdk_pipeline_options):
gcp_profiler_service_name = sdk_pipeline_options.view_as(
GoogleCloudOptions).get_cloud_profiler_service_name()
return gcp_profiler_service_name


def main(unused_argv):
"""Main entry point for SDK Fn Harness."""
(fn_log_handler, sdk_harness,
sdk_pipeline_options) = create_harness(os.environ)
gcp_profiler_name = _get_gcp_profiler_name_if_enabled(sdk_pipeline_options)
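  # When profiling is enabled, the job id (from the JOB_ID environment
  # variable set by the runner) is used as the profiler service version.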
if gcp_profiler_name:
_start_profiler(gcp_profiler_name, os.environ["JOB_ID"])
try:
_LOGGER.info('Python sdk harness starting.')
sdk_harness.run()
_LOGGER.info('Python sdk harness exiting.')
except: # pylint: disable=broad-except
_LOGGER.critical('Python sdk harness failed: ', exc_info=True)
raise
finally:
if fn_log_handler:
fn_log_handler.close()


def _load_pipeline_options(options_json):
if options_json is None:
return {}
options = json.loads(options_json)
# Check the options field first for backward compatibility.
if 'options' in options:
return options.get('options')
else:
# Remove extra urn part from the key.
portable_option_regex = r'^beam:option:(?P<key>.*):v1$'
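    # e.g. 'beam:option:job_name:v1' -> 'job_name'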
return {
re.match(portable_option_regex, k).group('key') if re.match(
portable_option_regex, k) else k: v
for k,
v in options.items()
}


def _parse_pipeline_options(options_json):
return PipelineOptions.from_dictionary(_load_pipeline_options(options_json))


def _get_state_cache_size_bytes(options):
  """Return the maximum size of state cache in bytes.

  Returns:
    an int indicating the maximum number of bytes to cache.
  """
max_cache_memory_usage_mb = options.view_as(
WorkerOptions).max_cache_memory_usage_mb
# to maintain backward compatibility
experiments = options.view_as(DebugOptions).experiments or []
for experiment in experiments:
    # There should be at most one match, so return from within the loop.
    if re.match(r'state_cache_size=', experiment):
      _LOGGER.warning(
          '--experiments=state_cache_size=X is deprecated and will be removed '
          'in future releases. '
          'Please use --max_cache_memory_usage_mb=X to set the cache size for '
          'user state API and side inputs.')
return int(
re.match(r'state_cache_size=(?P<state_cache_size>.*)',
experiment).group('state_cache_size')) << 20
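  # max_cache_memory_usage_mb is expressed in MiB; shift left by 20 bits to
  # convert to bytes.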
return max_cache_memory_usage_mb << 20


def _get_data_buffer_time_limit_ms(experiments):
  """Defines the time limit of the outbound data buffering.

  Note: data_buffer_time_limit_ms is an experimental flag and might
  not be available in future releases.

  Returns:
    an int indicating the time limit in milliseconds of the outbound
    data buffering. Default is 0 (disabled).
  """
for experiment in experiments:
    # There should be at most one match, so return from within the loop.
if re.match(r'data_buffer_time_limit_ms=', experiment):
return int(
re.match(
r'data_buffer_time_limit_ms=(?P<data_buffer_time_limit_ms>.*)',
experiment).group('data_buffer_time_limit_ms'))
return 0


def _get_log_level_from_options_dict(options_dict: dict) -> int:
  """Get log level from options dict's entry `default_sdk_harness_log_level`.

  If not specified, the default log level is logging.INFO.
  """
dict_level = options_dict.get('default_sdk_harness_log_level', 'INFO')
log_level = dict_level
if log_level.isdecimal():
log_level = int(log_level)
else:
    # A named log level such as 'DEBUG', 'INFO' or 'WARNING'.
log_level = getattr(logging, log_level, None)
if not isinstance(log_level, int):
# unknown log level.
_LOGGER.error("Unknown log level %s. Use default value INFO.", dict_level)
log_level = logging.INFO
return log_level


def _set_log_level_overrides(options_dict: dict) -> None:
  """Set module log level overrides from options dict's entry
  `sdk_harness_log_level_overrides`.
  """
parsed_overrides = options_dict.get('sdk_harness_log_level_overrides', None)
if not isinstance(parsed_overrides, dict):
if parsed_overrides is not None:
_LOGGER.error(
"Unable to parse sdk_harness_log_level_overrides: %s",
parsed_overrides)
return
for module_name, log_level in parsed_overrides.items():
try:
logging.getLogger(module_name).setLevel(log_level)
except Exception as e:
      # Never crash the worker if an exception occurs while setting a log
      # level; just log the error instead.
_LOGGER.error(
"Error occurred when setting log level for %s: %s", module_name, e)


class LoadMainSessionException(Exception):
"""
Used to crash this worker if a main session file failed to load.
"""
pass


def _load_main_session(semi_persistent_directory):
"""Loads a pickled main session from the path specified."""
if semi_persistent_directory:
session_file = os.path.join(
semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
if os.path.isfile(session_file):
      # If the expected session file is present but empty, it is likely that
      # the user code run by this worker will crash at runtime.
# This can happen if the worker fails to download the main session.
# Raise a fatal error and crash this worker, forcing a restart.
if os.path.getsize(session_file) == 0:
        # Potentially a transient error; unclear if it still happens.
raise LoadMainSessionException(
'Session file found, but empty: %s. Functions defined in __main__ '
'(interactive session) will almost certainly fail.' %
(session_file, ))
pickler.load_session(session_file)
else:
_LOGGER.warning(
'No session file found: %s. Functions defined in __main__ '
'(interactive session) may fail.',
session_file)
else:
_LOGGER.warning(
'No semi_persistent_directory found: Functions defined in __main__ '
'(interactive session) may fail.')


if __name__ == '__main__':
main(sys.argv)