blob: 9c14c25668eba989bfed9c87b0e26572691d16d2 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for the pipeline options module."""
from __future__ import absolute_import
import logging
import unittest
import hamcrest as hc
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
class PipelineOptionsTest(unittest.TestCase):
  """Tests for PipelineOptions parsing, views, display data and
  from_dictionary round-trips."""

  def tearDown(self):
    # Clean up the global variable used by RuntimeValueProvider
    RuntimeValueProvider.runtime_options = None

  # Each case lists the command-line flags to parse, the entries expected to
  # appear in get_all_options(), and the display-data items that parsing
  # those flags should produce.
  TEST_CASES = [
      {'flags': ['--num_workers', '5'],
       'expected': {'num_workers': 5,
                    'mock_flag': False,
                    'mock_option': None,
                    'mock_multi_option': None},
       'display_data': [DisplayDataItemMatcher('num_workers', 5)]},
      {
          'flags': [
              '--profile_cpu', '--profile_location', 'gs://bucket/', 'ignored'],
          'expected': {
              'profile_cpu': True, 'profile_location': 'gs://bucket/',
              'mock_flag': False, 'mock_option': None,
              'mock_multi_option': None},
          'display_data': [
              DisplayDataItemMatcher('profile_cpu',
                                     True),
              DisplayDataItemMatcher('profile_location',
                                     'gs://bucket/')]
      },
      {'flags': ['--num_workers', '5', '--mock_flag'],
       'expected': {'num_workers': 5,
                    'mock_flag': True,
                    'mock_option': None,
                    'mock_multi_option': None},
       'display_data': [
           DisplayDataItemMatcher('num_workers', 5),
           DisplayDataItemMatcher('mock_flag', True)]
      },
      {'flags': ['--mock_option', 'abc'],
       'expected': {'mock_flag': False,
                    'mock_option': 'abc',
                    'mock_multi_option': None},
       'display_data': [
           DisplayDataItemMatcher('mock_option', 'abc')]
      },
      {'flags': ['--mock_option', ' abc def '],
       'expected': {'mock_flag': False,
                    'mock_option': ' abc def ',
                    'mock_multi_option': None},
       'display_data': [
           DisplayDataItemMatcher('mock_option', ' abc def ')]
      },
      {'flags': ['--mock_option= abc xyz '],
       'expected': {'mock_flag': False,
                    'mock_option': ' abc xyz ',
                    'mock_multi_option': None},
       'display_data': [
           DisplayDataItemMatcher('mock_option', ' abc xyz ')]
      },
      {'flags': ['--mock_option=gs://my bucket/my folder/my file',
                 '--mock_multi_option=op1',
                 '--mock_multi_option=op2'],
       'expected': {'mock_flag': False,
                    'mock_option': 'gs://my bucket/my folder/my file',
                    'mock_multi_option': ['op1', 'op2']},
       'display_data': [
           DisplayDataItemMatcher(
               'mock_option', 'gs://my bucket/my folder/my file'),
           DisplayDataItemMatcher('mock_multi_option', ['op1', 'op2'])]
      },
      {'flags': ['--mock_multi_option=op1', '--mock_multi_option=op2'],
       'expected': {'mock_flag': False,
                    'mock_option': None,
                    'mock_multi_option': ['op1', 'op2']},
       'display_data': [
           DisplayDataItemMatcher('mock_multi_option', ['op1', 'op2'])]
      },
      {'flags': ['--flink_master=testmaster:8081', '--parallelism=42'],
       'expected': {'flink_master': 'testmaster:8081',
                    'parallelism': 42,
                    'mock_flag': False,
                    'mock_option': None,
                    'mock_multi_option': None},
       'display_data': [
           DisplayDataItemMatcher('flink_master', 'testmaster:8081'),
           DisplayDataItemMatcher('parallelism', 42)]
      },
  ]

  # Used for testing newly added flags.
  class MockOptions(PipelineOptions):

    @classmethod
    def _add_argparse_args(cls, parser):
      parser.add_argument('--mock_flag', action='store_true', help='mock flag')
      parser.add_argument('--mock_option', help='mock option')
      parser.add_argument(
          '--mock_multi_option', action='append', help='mock multi option')
      parser.add_argument('--option with space', help='mock option with space')

  def test_display_data(self):
    for case in PipelineOptionsTest.TEST_CASES:
      options = PipelineOptions(flags=case['flags'])
      dd = DisplayData.create_from(options)
      hc.assert_that(dd.items, hc.contains_inanyorder(*case['display_data']))

  def test_get_all_options(self):
    for case in PipelineOptionsTest.TEST_CASES:
      expected = case['expected']
      options = PipelineOptions(flags=case['flags'])
      all_options = options.get_all_options()
      # assertDictContainsSubset is deprecated (and removed in Python 3.12),
      # so compare the expected entries against the matching slice of the
      # actual options. A missing key shows up as a mismatch in the diff.
      actual_subset = {key: all_options[key]
                       for key in expected if key in all_options}
      self.assertEqual(expected, actual_subset)
      mock_options = options.view_as(PipelineOptionsTest.MockOptions)
      self.assertEqual(mock_options.mock_flag, expected['mock_flag'])
      self.assertEqual(mock_options.mock_option, expected['mock_option'])
      self.assertEqual(mock_options.mock_multi_option,
                       expected['mock_multi_option'])

  def test_from_dictionary(self):
    for case in PipelineOptionsTest.TEST_CASES:
      options = PipelineOptions(flags=case['flags'])
      all_options_dict = options.get_all_options()
      options_from_dict = PipelineOptions.from_dictionary(all_options_dict)
      self.assertEqual(options_from_dict.view_as(
          PipelineOptionsTest.MockOptions).mock_flag,
                       case['expected']['mock_flag'])
      # NOTE(review): this assertion reads `options`, not `options_from_dict`,
      # so it does not exercise from_dictionary for mock_option. Before
      # switching it, confirm how from_dictionary serializes None values
      # (they may come back as the string 'None').
      self.assertEqual(options.view_as(
          PipelineOptionsTest.MockOptions).mock_option,
                       case['expected']['mock_option'])

  def test_option_with_space(self):
    options = PipelineOptions(flags=['--option with space= value with space'])
    self.assertEqual(
        getattr(options.view_as(PipelineOptionsTest.MockOptions),
                'option with space'), ' value with space')
    options_from_dict = PipelineOptions.from_dictionary(
        options.get_all_options())
    self.assertEqual(
        getattr(options_from_dict.view_as(PipelineOptionsTest.MockOptions),
                'option with space'), ' value with space')

  def test_override_options(self):
    base_flags = ['--num_workers', '5']
    options = PipelineOptions(base_flags)
    self.assertEqual(options.get_all_options()['num_workers'], 5)
    self.assertEqual(options.get_all_options()['mock_flag'], False)

    # Setting an attribute on a view must be reflected in get_all_options().
    options.view_as(PipelineOptionsTest.MockOptions).mock_flag = True
    self.assertEqual(options.get_all_options()['num_workers'], 5)
    self.assertTrue(options.get_all_options()['mock_flag'])

  def test_experiments(self):
    # Both the singular and plural spellings accumulate into 'experiments'.
    options = PipelineOptions(['--experiment', 'abc', '--experiment', 'def'])
    self.assertEqual(
        sorted(options.get_all_options()['experiments']), ['abc', 'def'])

    options = PipelineOptions(['--experiments', 'abc', '--experiments', 'def'])
    self.assertEqual(
        sorted(options.get_all_options()['experiments']), ['abc', 'def'])

    options = PipelineOptions(flags=[''])
    self.assertIsNone(options.get_all_options()['experiments'])

  def test_extra_package(self):
    options = PipelineOptions(['--extra_package', 'abc',
                               '--extra_packages', 'def',
                               '--extra_packages', 'ghi'])
    self.assertEqual(
        sorted(options.get_all_options()['extra_packages']),
        ['abc', 'def', 'ghi'])

    options = PipelineOptions(flags=[''])
    self.assertIsNone(options.get_all_options()['extra_packages'])

  def test_dataflow_job_file(self):
    options = PipelineOptions(['--dataflow_job_file', 'abc'])
    self.assertEqual(options.get_all_options()['dataflow_job_file'], 'abc')

    options = PipelineOptions(flags=[''])
    self.assertIsNone(options.get_all_options()['dataflow_job_file'])

  def test_template_location(self):
    options = PipelineOptions(['--template_location', 'abc'])
    self.assertEqual(options.get_all_options()['template_location'], 'abc')

    options = PipelineOptions(flags=[''])
    self.assertIsNone(options.get_all_options()['template_location'])

  def test_redefine_options(self):
    # Two PipelineOptions subclasses deliberately share one name and register
    # the same flag; parsing must still succeed. Keep both names identical.
    class TestRedefinedOptions(PipelineOptions):  # pylint: disable=unused-variable

      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_argument('--redefined_flag', action='store_true')

    class TestRedefinedOptions(PipelineOptions):  # pylint: disable=function-redefined

      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_argument('--redefined_flag', action='store_true')

    options = PipelineOptions(['--redefined_flag'])
    self.assertTrue(options.get_all_options()['redefined_flag'])

  # TODO(BEAM-1319): Require unique names only within a test.
  # For now, <file name acronym>_vp_arg<number> will be the convention
  # to name value-provider arguments in tests, as opposed to
  # <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
  # The number will grow per file as tests are added.
  def test_value_provider_options(self):
    class UserOptions(PipelineOptions):

      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--pot_vp_arg1',
            help='This flag is a value provider')

        parser.add_value_provider_argument(
            '--pot_vp_arg2',
            default=1,
            type=int)

        parser.add_argument(
            '--pot_non_vp_arg1',
            default=1,
            type=int
        )

    # Only pot_vp_arg1 is given a value: it becomes a StaticValueProvider,
    # while the unset value-provider option becomes a RuntimeValueProvider.
    options = UserOptions(['--pot_vp_arg1', 'hello'])
    self.assertIsInstance(options.pot_vp_arg1, StaticValueProvider)
    self.assertIsInstance(options.pot_vp_arg2, RuntimeValueProvider)
    self.assertIsInstance(options.pot_non_vp_arg1, int)

    # Values can be overwritten
    options = UserOptions(pot_vp_arg1=5,
                          pot_vp_arg2=StaticValueProvider(value_type=str,
                                                          value='bye'),
                          pot_non_vp_arg1=RuntimeValueProvider(
                              option_name='foo',
                              value_type=int,
                              default_value=10))
    self.assertEqual(options.pot_vp_arg1, 5)
    self.assertTrue(options.pot_vp_arg2.is_accessible(),
                    '%s is not accessible' % options.pot_vp_arg2)
    self.assertEqual(options.pot_vp_arg2.get(), 'bye')
    self.assertFalse(options.pot_non_vp_arg1.is_accessible())

    with self.assertRaises(RuntimeError):
      options.pot_non_vp_arg1.get()
if __name__ == '__main__':
  # When run as a script, raise the root logger to INFO so log output
  # from the code under test is visible, then start the unittest runner.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()