blob: 9818422e2f6949abdad2610374a2dd8d8787ce62 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for the pipeline options validator module."""
from __future__ import absolute_import
import logging
import unittest
from builtins import object
from hamcrest.core.base_matcher import BaseMatcher
from apache_beam.internal import pickler
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator
# Mock runners to use for validations.
class MockRunners(object):
  """Namespace of stand-in runner classes used to drive validator checks.

  The validator only inspects the runner's class name, so empty classes
  are sufficient here.
  """

  class DataflowRunner(object):
    """Stands in for the Dataflow service runner."""

  class TestDataflowRunner(object):
    """Stands in for the test variant of the Dataflow runner."""

  class OtherRunner(object):
    """Stands in for any non-Dataflow (e.g. local) runner."""
# Matcher that always passes for testing on_success_matcher option
class AlwaysPassMatcher(BaseMatcher):
  """Hamcrest matcher that matches any item, for on_success_matcher tests."""

  def _matches(self, item):
    # Unconditionally succeed regardless of the item under test.
    return True
class SetupTest(unittest.TestCase):
  """Tests for PipelineOptionsValidator using mock runner classes."""

  def check_errors_for_arguments(self, errors, args):
    """Checks that there is exactly one error for each given argument.

    Args:
      errors: Error strings returned by the validator.
      args: Argument names, each of which must be mentioned by exactly
        one error.

    Returns:
      A list combining a "missing" message for every argument with no
      matching error, plus any validator errors left unmatched. Empty
      when errors and args correspond one-to-one.
    """
    missing = []
    remaining = list(errors)
    for arg in args:
      found = False
      for error in remaining:
        if arg in error:
          remaining.remove(error)
          found = True
          break
      if not found:
        missing.append('Missing error for: ' + arg)
    # Return missing and remaining (not matched) errors.
    return missing + remaining

  def test_local_runner(self):
    """A non-Dataflow runner needs no Dataflow-specific options."""
    runner = MockRunners.OtherRunner()
    options = PipelineOptions([])
    validator = PipelineOptionsValidator(options, runner)
    errors = validator.validate()
    self.assertEqual(len(errors), 0)

  def test_missing_required_options(self):
    """The Dataflow runner reports every absent required option."""
    options = PipelineOptions([''])
    runner = MockRunners.DataflowRunner()
    validator = PipelineOptionsValidator(options, runner)
    errors = validator.validate()
    self.assertEqual(
        self.check_errors_for_arguments(
            errors,
            ['project', 'staging_location', 'temp_location']),
        [])

  def test_gcs_path(self):
    """temp_location/staging_location must be well-formed gs:// paths."""
    def get_validator(temp_location, staging_location):
      options = ['--project=example:example', '--job_name=job']
      if temp_location is not None:
        options.append('--temp_location=' + temp_location)
      if staging_location is not None:
        options.append('--staging_location=' + staging_location)
      pipeline_options = PipelineOptions(options)
      runner = MockRunners.DataflowRunner()
      validator = PipelineOptionsValidator(pipeline_options, runner)
      return validator

    test_cases = [
        {'temp_location': None,
         'staging_location': 'gs://foo/bar',
         'errors': ['temp_location']},
        {'temp_location': None,
         'staging_location': None,
         'errors': ['staging_location', 'temp_location']},
        # staging_location falls back to temp_location when omitted.
        {'temp_location': 'gs://foo/bar',
         'staging_location': None,
         'errors': []},
        {'temp_location': 'gs://foo/bar',
         'staging_location': 'gs://ABC/bar',
         'errors': ['staging_location']},
        {'temp_location': 'gcs:/foo/bar',
         'staging_location': 'gs://foo/bar',
         'errors': ['temp_location']},
        {'temp_location': 'gs:/foo/bar',
         'staging_location': 'gs://foo/bar',
         'errors': ['temp_location']},
        # Bucket names must be lower case.
        {'temp_location': 'gs://ABC/bar',
         'staging_location': 'gs://foo/bar',
         'errors': ['temp_location']},
        # A bare bucket with no trailing slash or object path is invalid.
        {'temp_location': 'gs://foo',
         'staging_location': 'gs://foo/bar',
         'errors': ['temp_location']},
        {'temp_location': 'gs://foo/',
         'staging_location': 'gs://foo/bar',
         'errors': []},
        {'temp_location': 'gs://foo/bar',
         'staging_location': 'gs://foo/bar',
         'errors': []},
    ]

    for case in test_cases:
      errors = get_validator(case['temp_location'],
                             case['staging_location']).validate()
      self.assertEqual(
          self.check_errors_for_arguments(errors, case['errors']), [])

  def test_project(self):
    """Project ids must be present and match the GCP naming rules."""
    def get_validator(project):
      options = ['--job_name=job', '--staging_location=gs://foo/bar',
                 '--temp_location=gs://foo/bar']
      if project is not None:
        options.append('--project=' + project)
      pipeline_options = PipelineOptions(options)
      runner = MockRunners.DataflowRunner()
      validator = PipelineOptionsValidator(pipeline_options, runner)
      return validator

    test_cases = [
        {'project': None, 'errors': ['project']},
        {'project': '12345', 'errors': ['project']},
        {'project': 'FOO', 'errors': ['project']},
        {'project': 'foo:BAR', 'errors': ['project']},
        {'project': 'fo', 'errors': ['project']},
        {'project': 'foo', 'errors': []},
        {'project': 'foo:bar', 'errors': []},
    ]

    for case in test_cases:
      errors = get_validator(case['project']).validate()
      self.assertEqual(
          self.check_errors_for_arguments(errors, case['errors']), [])

  def test_job_name(self):
    """Job names are optional but must match the naming rules when set."""
    def get_validator(job_name):
      options = ['--project=example:example', '--staging_location=gs://foo/bar',
                 '--temp_location=gs://foo/bar']
      if job_name is not None:
        options.append('--job_name=' + job_name)
      pipeline_options = PipelineOptions(options)
      runner = MockRunners.DataflowRunner()
      validator = PipelineOptionsValidator(pipeline_options, runner)
      return validator

    test_cases = [
        {'job_name': None, 'errors': []},
        {'job_name': '12345', 'errors': ['job_name']},
        {'job_name': 'FOO', 'errors': ['job_name']},
        {'job_name': 'foo:bar', 'errors': ['job_name']},
        {'job_name': 'fo', 'errors': []},
        {'job_name': 'foo', 'errors': []},
    ]

    for case in test_cases:
      errors = get_validator(case['job_name']).validate()
      self.assertEqual(
          self.check_errors_for_arguments(errors, case['errors']), [])

  def test_num_workers(self):
    """num_workers, when given, must be a positive integer."""
    def get_validator(num_workers):
      options = ['--project=example:example', '--job_name=job',
                 '--staging_location=gs://foo/bar',
                 '--temp_location=gs://foo/bar']
      if num_workers is not None:
        options.append('--num_workers=' + num_workers)
      pipeline_options = PipelineOptions(options)
      runner = MockRunners.DataflowRunner()
      validator = PipelineOptionsValidator(pipeline_options, runner)
      return validator

    test_cases = [
        {'num_workers': None, 'errors': []},
        {'num_workers': '1', 'errors': []},
        {'num_workers': '0', 'errors': ['num_workers']},
        {'num_workers': '-1', 'errors': ['num_workers']},
    ]

    for case in test_cases:
      errors = get_validator(case['num_workers']).validate()
      self.assertEqual(
          self.check_errors_for_arguments(errors, case['errors']), [])

  def test_is_service_runner(self):
    """is_service_runner depends on both runner class and endpoint."""
    test_cases = [
        {
            'runner': MockRunners.OtherRunner(),
            'options': [],
            'expected': False,
        },
        {
            'runner': MockRunners.OtherRunner(),
            'options': ['--dataflow_endpoint=https://dataflow.googleapis.com'],
            'expected': False,
        },
        {
            'runner': MockRunners.OtherRunner(),
            'options': ['--dataflow_endpoint=https://dataflow.googleapis.com/'],
            'expected': False,
        },
        {
            'runner': MockRunners.DataflowRunner(),
            'options': ['--dataflow_endpoint=https://another.service.com'],
            'expected': False,
        },
        {
            'runner': MockRunners.DataflowRunner(),
            'options': ['--dataflow_endpoint=https://another.service.com/'],
            'expected': False,
        },
        {
            'runner': MockRunners.DataflowRunner(),
            'options': ['--dataflow_endpoint=https://dataflow.googleapis.com'],
            'expected': True,
        },
        {
            'runner': MockRunners.DataflowRunner(),
            'options': ['--dataflow_endpoint=https://dataflow.googleapis.com/'],
            'expected': True,
        },
        {
            'runner': MockRunners.DataflowRunner(),
            'options': [],
            'expected': True,
        },
    ]

    for case in test_cases:
      validator = PipelineOptionsValidator(
          PipelineOptions(case['options']), case['runner'])
      self.assertEqual(validator.is_service_runner(), case['expected'])

  def test_dataflow_job_file_and_template_location_mutually_exclusive(self):
    """Setting both template_location and dataflow_job_file is an error."""
    runner = MockRunners.OtherRunner()
    options = PipelineOptions([
        '--template_location', 'abc',
        '--dataflow_job_file', 'def'
    ])
    validator = PipelineOptionsValidator(options, runner)
    errors = validator.validate()
    self.assertTrue(errors)

  def test_validate_template_location(self):
    """template_location alone validates cleanly."""
    runner = MockRunners.OtherRunner()
    options = PipelineOptions([
        '--template_location', 'abc',
    ])
    validator = PipelineOptionsValidator(options, runner)
    errors = validator.validate()
    self.assertFalse(errors)

  def test_validate_dataflow_job_file(self):
    """dataflow_job_file alone validates cleanly."""
    runner = MockRunners.OtherRunner()
    options = PipelineOptions([
        '--dataflow_job_file', 'abc'
    ])
    validator = PipelineOptionsValidator(options, runner)
    errors = validator.validate()
    self.assertFalse(errors)

  def test_test_matcher(self):
    """on_success_matcher must be a pickled BaseMatcher instance."""
    def get_validator(matcher):
      options = ['--project=example:example',
                 '--job_name=job',
                 '--staging_location=gs://foo/bar',
                 '--temp_location=gs://foo/bar',]
      if matcher:
        options.append('%s=%s' % ('--on_success_matcher', matcher.decode()))
      pipeline_options = PipelineOptions(options)
      runner = MockRunners.TestDataflowRunner()
      return PipelineOptionsValidator(pipeline_options, runner)

    test_cases = [
        {'on_success_matcher': None,
         'errors': []},
        {'on_success_matcher': pickler.dumps(AlwaysPassMatcher()),
         'errors': []},
        # Not unpicklable at all.
        {'on_success_matcher': b'abc',
         'errors': ['on_success_matcher']},
        # Unpickles, but is not a BaseMatcher instance.
        {'on_success_matcher': pickler.dumps(object),
         'errors': ['on_success_matcher']},
    ]

    for case in test_cases:
      errors = get_validator(case['on_success_matcher']).validate()
      self.assertEqual(
          self.check_errors_for_arguments(errors, case['errors']), [])
if __name__ == '__main__':
  # Show INFO-level logs from the validator while the tests run.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()