| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Unit tests for the apiclient module.""" |
| |
| # pytype: skip-file |
| |
| import json |
| import logging |
| import sys |
| import unittest |
| |
| import mock |
| |
| from apache_beam.metrics.cells import DistributionData |
| from apache_beam.options.pipeline_options import GoogleCloudOptions |
| from apache_beam.options.pipeline_options import PipelineOptions |
| from apache_beam.pipeline import Pipeline |
| from apache_beam.portability import common_urns |
| from apache_beam.portability.api import beam_runner_api_pb2 |
| from apache_beam.runners.dataflow.internal import names |
| from apache_beam.runners.dataflow.internal.clients import dataflow |
| from apache_beam.transforms import Create |
| from apache_beam.transforms import DataflowDistributionCounter |
| from apache_beam.transforms import DoFn |
| from apache_beam.transforms import ParDo |
from apache_beam.transforms.environments import DockerEnvironment
from apache_beam.utils import proto_utils
| |
| # Protect against environments where apitools library is not available. |
| # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports |
| try: |
| from apache_beam.runners.dataflow.internal import apiclient |
| except ImportError: |
| apiclient = None # type: ignore |
| # pylint: enable=wrong-import-order, wrong-import-position |
| |
| FAKE_PIPELINE_URL = "gs://invalid-bucket/anywhere" |
| _LOGGER = logging.getLogger(__name__) |
| |
| |
| @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') |
| class UtilTest(unittest.TestCase): |
| @unittest.skip("Enable once BEAM-1080 is fixed.") |
| def test_create_application_client(self): |
| pipeline_options = PipelineOptions() |
| apiclient.DataflowApplicationClient(pipeline_options) |
| |
| def test_pipeline_url(self): |
| pipeline_options = PipelineOptions([ |
| '--subnetwork', |
| '/regions/MY/subnetworks/SUBNETWORK', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| env = apiclient.Environment( |
| [], |
| pipeline_options, |
| '2.0.0', # any environment version |
| FAKE_PIPELINE_URL) |
| |
| recovered_options = None |
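    # for/else: the else branch runs only when no 'options' key is found.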
| for additionalProperty in env.proto.sdkPipelineOptions.additionalProperties: |
| if additionalProperty.key == 'options': |
| recovered_options = additionalProperty.value |
| break |
| else: |
| self.fail( |
| 'No pipeline options found in %s' % env.proto.sdkPipelineOptions) |
| |
| pipeline_url = None |
    for prop in recovered_options.object_value.properties:
      if prop.key == 'pipelineUrl':
        pipeline_url = prop.value
| break |
| else: |
| self.fail('No pipeline_url found in %s' % recovered_options) |
| |
| self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL) |
| |
| def test_set_network(self): |
| pipeline_options = PipelineOptions([ |
| '--network', |
| 'anetworkname', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual(env.proto.workerPools[0].network, 'anetworkname') |
| |
| def test_set_subnetwork(self): |
| pipeline_options = PipelineOptions([ |
| '--subnetwork', |
| '/regions/MY/subnetworks/SUBNETWORK', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].subnetwork, |
| '/regions/MY/subnetworks/SUBNETWORK') |
| |
| def test_flexrs_blank(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp']) |
| |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
    self.assertIsNone(env.proto.flexResourceSchedulingGoal)
| |
| def test_flexrs_cost(self): |
| pipeline_options = PipelineOptions([ |
| '--flexrs_goal', |
| 'COST_OPTIMIZED', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.flexResourceSchedulingGoal, |
| ( |
| dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum. |
| FLEXRS_COST_OPTIMIZED)) |
| |
| def test_flexrs_speed(self): |
| pipeline_options = PipelineOptions([ |
| '--flexrs_goal', |
| 'SPEED_OPTIMIZED', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.flexResourceSchedulingGoal, |
| ( |
| dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum. |
| FLEXRS_SPEED_OPTIMIZED)) |
| |
| def test_default_environment_get_set(self): |
| |
| pipeline_options = PipelineOptions([ |
| '--experiments=beam_fn_api', |
| '--experiments=use_unified_worker', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| pipeline = Pipeline(options=pipeline_options) |
| pipeline | Create([1, 2, 3]) | ParDo(DoFn()) # pylint:disable=expression-not-assigned |
| |
| test_environment = DockerEnvironment(container_image='test_default_image') |
| proto_pipeline, _ = pipeline.to_runner_api( |
| return_context=True, default_environment=test_environment) |
| |
| dummy_env = beam_runner_api_pb2.Environment( |
| urn=common_urns.environments.DOCKER.urn, |
| payload=( |
| beam_runner_api_pb2.DockerPayload( |
| container_image='dummy_image')).SerializeToString()) |
| proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env) |
| |
| dummy_transform = beam_runner_api_pb2.PTransform( |
| environment_id='dummy_env_id') |
| proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom( |
| dummy_transform) |
| |
| env = apiclient.Environment( |
| [], # packages |
| pipeline_options, |
| '2.0.0', # any environment version |
| FAKE_PIPELINE_URL, |
| proto_pipeline, |
| _sdk_image_overrides={ |
| '.*dummy.*': 'dummy_image', '.*test.*': 'test_default_image' |
| }) |
| worker_pool = env.proto.workerPools[0] |
| |
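    # One SDK harness container per environment: the default environment plus
    # the manually added dummy environment.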
| self.assertEqual(2, len(worker_pool.sdkHarnessContainerImages)) |
| |
| images_from_proto = [ |
| sdk_info.containerImage |
| for sdk_info in worker_pool.sdkHarnessContainerImages |
| ] |
| self.assertIn('test_default_image', images_from_proto) |
| |
| def test_sdk_harness_container_image_overrides(self): |
| test_environment = DockerEnvironment( |
| container_image='dummy_container_image') |
| proto_pipeline, _ = Pipeline().to_runner_api( |
| return_context=True, default_environment=test_environment) |
| |
| pipeline_options = PipelineOptions([ |
| '--experiments=beam_fn_api', |
| '--experiments=use_unified_worker', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| # Accessing non-public method for testing. |
| apiclient.DataflowApplicationClient._apply_sdk_environment_overrides( |
| proto_pipeline, |
| { |
| '.*dummy.*': 'new_dummy_container_image', |
| '.*notfound.*': 'new_dummy_container_image_2' |
| }, |
| pipeline_options) |
| |
    self.assertEqual(1, len(proto_pipeline.components.environments))
| env = list(proto_pipeline.components.environments.values())[0] |
| |
| docker_payload = proto_utils.parse_Bytes( |
| env.payload, beam_runner_api_pb2.DockerPayload) |
| |
    # Container image should be overridden by the given override.
| self.assertEqual( |
| docker_payload.container_image, 'new_dummy_container_image') |
| |
| def test_dataflow_container_image_override(self): |
| pipeline_options = PipelineOptions([ |
| '--experiments=beam_fn_api', |
| '--experiments=use_unified_worker', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| pipeline = Pipeline(options=pipeline_options) |
| pipeline | Create([1, 2, 3]) | ParDo(DoFn()) # pylint:disable=expression-not-assigned |
| |
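    # Default apache/beam SDK images are expected to be swapped for images
    # hosted in the Dataflow container repository.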
| dummy_env = DockerEnvironment( |
| container_image='apache/beam_dummy_name:dummy_tag') |
| proto_pipeline, _ = pipeline.to_runner_api( |
| return_context=True, default_environment=dummy_env) |
| |
| # Accessing non-public method for testing. |
| apiclient.DataflowApplicationClient._apply_sdk_environment_overrides( |
| proto_pipeline, dict(), pipeline_options) |
| |
| found_override = False |
| for env in proto_pipeline.components.environments.values(): |
| docker_payload = proto_utils.parse_Bytes( |
| env.payload, beam_runner_api_pb2.DockerPayload) |
| if docker_payload.container_image.startswith( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY): |
| found_override = True |
| |
| self.assertTrue(found_override) |
| |
| def test_non_apache_container_not_overridden(self): |
| pipeline_options = PipelineOptions([ |
| '--experiments=beam_fn_api', |
| '--experiments=use_unified_worker', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| |
| pipeline = Pipeline(options=pipeline_options) |
| pipeline | Create([1, 2, 3]) | ParDo(DoFn()) # pylint:disable=expression-not-assigned |
| |
| dummy_env = DockerEnvironment( |
| container_image='other_org/dummy_name:dummy_tag') |
| proto_pipeline, _ = pipeline.to_runner_api( |
| return_context=True, default_environment=dummy_env) |
| |
| # Accessing non-public method for testing. |
| apiclient.DataflowApplicationClient._apply_sdk_environment_overrides( |
| proto_pipeline, dict(), pipeline_options) |
| |
    self.assertEqual(1, len(proto_pipeline.components.environments))
| |
| found_override = False |
| for env in proto_pipeline.components.environments.values(): |
| docker_payload = proto_utils.parse_Bytes( |
| env.payload, beam_runner_api_pb2.DockerPayload) |
| if docker_payload.container_image.startswith( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY): |
| found_override = True |
| |
| self.assertFalse(found_override) |
| |
| def test_pipeline_sdk_not_overridden(self): |
| pipeline_options = PipelineOptions([ |
| '--experiments=beam_fn_api', |
| '--experiments=use_unified_worker', |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag' |
| ]) |
| |
| pipeline = Pipeline(options=pipeline_options) |
| pipeline | Create([1, 2, 3]) | ParDo(DoFn()) # pylint:disable=expression-not-assigned |
| |
| dummy_env = DockerEnvironment( |
| container_image='dummy_prefix/dummy_name:dummy_tag') |
| proto_pipeline, _ = pipeline.to_runner_api( |
| return_context=True, default_environment=dummy_env) |
| |
| # Accessing non-public method for testing. |
| apiclient.DataflowApplicationClient._apply_sdk_environment_overrides( |
| proto_pipeline, dict(), pipeline_options) |
| |
    self.assertEqual(1, len(proto_pipeline.components.environments))
| |
| found_override = False |
| for env in proto_pipeline.components.environments.values(): |
| docker_payload = proto_utils.parse_Bytes( |
| env.payload, beam_runner_api_pb2.DockerPayload) |
| if docker_payload.container_image.startswith( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY): |
| found_override = True |
| |
| self.assertFalse(found_override) |
| |
| def test_invalid_default_job_name(self): |
| # Regexp for job names in dataflow. |
| regexp = '^[a-z]([-a-z0-9]{0,61}[a-z0-9])?$' |
| |
| job_name = apiclient.Job._build_default_job_name('invalid.-_user_n*/ame') |
| self.assertRegex(job_name, regexp) |
| |
| job_name = apiclient.Job._build_default_job_name( |
| 'invalid-extremely-long.username_that_shouldbeshortened_or_is_invalid') |
| self.assertRegex(job_name, regexp) |
| |
| def test_default_job_name(self): |
| job_name = apiclient.Job.default_job_name(None) |
| regexp = 'beamapp-.*-[0-9]{10}-[0-9]{6}' |
| self.assertRegex(job_name, regexp) |
| |
| def test_split_int(self): |
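    # to_split_int encodes a 64-bit integer as 32-bit lowBits and highBits.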
| number = 12345 |
| split_number = apiclient.to_split_int(number) |
| self.assertEqual((split_number.lowBits, split_number.highBits), (number, 0)) |
| shift_number = number << 32 |
| split_number = apiclient.to_split_int(shift_number) |
| self.assertEqual((split_number.lowBits, split_number.highBits), (0, number)) |
| |
| def test_translate_distribution_using_accumulator(self): |
| metric_update = dataflow.CounterUpdate() |
| accumulator = mock.Mock() |
| accumulator.min = 1 |
| accumulator.max = 15 |
| accumulator.sum = 16 |
| accumulator.count = 2 |
| apiclient.translate_distribution(accumulator, metric_update) |
| self.assertEqual(metric_update.distribution.min.lowBits, accumulator.min) |
| self.assertEqual(metric_update.distribution.max.lowBits, accumulator.max) |
| self.assertEqual(metric_update.distribution.sum.lowBits, accumulator.sum) |
| self.assertEqual( |
| metric_update.distribution.count.lowBits, accumulator.count) |
| |
| def test_translate_distribution_using_distribution_data(self): |
| metric_update = dataflow.CounterUpdate() |
| distribution_update = DistributionData(16, 2, 1, 15) |
| apiclient.translate_distribution(distribution_update, metric_update) |
| self.assertEqual( |
| metric_update.distribution.min.lowBits, distribution_update.min) |
| self.assertEqual( |
| metric_update.distribution.max.lowBits, distribution_update.max) |
| self.assertEqual( |
| metric_update.distribution.sum.lowBits, distribution_update.sum) |
| self.assertEqual( |
| metric_update.distribution.count.lowBits, distribution_update.count) |
| |
| def test_translate_distribution_using_dataflow_distribution_counter(self): |
| counter_update = DataflowDistributionCounter() |
| counter_update.add_input(1) |
| counter_update.add_input(3) |
| metric_proto = dataflow.CounterUpdate() |
| apiclient.translate_distribution(counter_update, metric_proto) |
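    # translate_to_histogram fills in the mocked histogram fields, which are
    # then compared against the translated proto below.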
| histogram = mock.Mock(firstBucketOffset=None, bucketCounts=None) |
| counter_update.translate_to_histogram(histogram) |
| self.assertEqual(metric_proto.distribution.min.lowBits, counter_update.min) |
| self.assertEqual(metric_proto.distribution.max.lowBits, counter_update.max) |
| self.assertEqual(metric_proto.distribution.sum.lowBits, counter_update.sum) |
| self.assertEqual( |
| metric_proto.distribution.count.lowBits, counter_update.count) |
| self.assertEqual( |
| metric_proto.distribution.histogram.bucketCounts, |
| histogram.bucketCounts) |
| self.assertEqual( |
| metric_proto.distribution.histogram.firstBucketOffset, |
| histogram.firstBucketOffset) |
| |
| def test_translate_means(self): |
| metric_update = dataflow.CounterUpdate() |
| accumulator = mock.Mock() |
| accumulator.sum = 16 |
| accumulator.count = 2 |
| apiclient.MetricUpdateTranslators.translate_scalar_mean_int( |
| accumulator, metric_update) |
| self.assertEqual(metric_update.integerMean.sum.lowBits, accumulator.sum) |
| self.assertEqual(metric_update.integerMean.count.lowBits, accumulator.count) |
| |
| accumulator.sum = 16.0 |
| accumulator.count = 2 |
| apiclient.MetricUpdateTranslators.translate_scalar_mean_float( |
| accumulator, metric_update) |
| self.assertEqual(metric_update.floatingPointMean.sum, accumulator.sum) |
| self.assertEqual( |
| metric_update.floatingPointMean.count.lowBits, accumulator.count) |
| |
| def test_translate_means_using_distribution_accumulator(self): |
    # This is the special case for MeanByteCount, which is reported over the
    # FnAPI as a Beam distribution and to the service as a MetricUpdate
    # IntegerMean.
| metric_update = dataflow.CounterUpdate() |
| accumulator = mock.Mock() |
| accumulator.min = 7 |
| accumulator.max = 9 |
| accumulator.sum = 16 |
| accumulator.count = 2 |
| apiclient.MetricUpdateTranslators.translate_scalar_mean_int( |
| accumulator, metric_update) |
| self.assertEqual(metric_update.integerMean.sum.lowBits, accumulator.sum) |
| self.assertEqual(metric_update.integerMean.count.lowBits, accumulator.count) |
| |
| accumulator.sum = 16.0 |
| accumulator.count = 2 |
| apiclient.MetricUpdateTranslators.translate_scalar_mean_float( |
| accumulator, metric_update) |
| self.assertEqual(metric_update.floatingPointMean.sum, accumulator.sum) |
| self.assertEqual( |
| metric_update.floatingPointMean.count.lowBits, accumulator.count) |
| |
| def test_default_ip_configuration(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp']) |
| env = apiclient.Environment([], |
| pipeline_options, |
| '2.0.0', |
| FAKE_PIPELINE_URL) |
    self.assertIsNone(env.proto.workerPools[0].ipConfiguration)
| |
| def test_public_ip_configuration(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--use_public_ips']) |
| env = apiclient.Environment([], |
| pipeline_options, |
| '2.0.0', |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].ipConfiguration, |
| dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC) |
| |
| def test_private_ip_configuration(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--no_use_public_ips']) |
| env = apiclient.Environment([], |
| pipeline_options, |
| '2.0.0', |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].ipConfiguration, |
| dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE) |
| |
| def test_number_of_worker_harness_threads(self): |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--number_of_worker_harness_threads', |
| '2' |
| ]) |
| env = apiclient.Environment([], |
| pipeline_options, |
| '2.0.0', |
| FAKE_PIPELINE_URL) |
| self.assertEqual(env.proto.workerPools[0].numThreadsPerWorker, 2) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_harness_override_default_in_released_sdks(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--streaming']) |
| override = ''.join([ |
| 'runner_harness_container_image=', |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY, |
| '/harness:2.2.0' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertIn(override, env.proto.experiments) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_harness_override_absent_in_released_sdks_with_runner_v2(self): |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--streaming', |
| '--experiments=use_runner_v2' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| if env.proto.experiments: |
| for experiment in env.proto.experiments: |
| self.assertNotIn('runner_harness_container_image=', experiment) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_harness_override_custom_in_released_sdks(self): |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--streaming', |
| '--experiments=runner_harness_container_image=fake_image' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| 1, |
| len([ |
| x for x in env.proto.experiments |
| if x.startswith('runner_harness_container_image=') |
| ])) |
| self.assertIn( |
| 'runner_harness_container_image=fake_image', env.proto.experiments) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_harness_override_custom_in_released_sdks_with_runner_v2(self): |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--streaming', |
| '--experiments=runner_harness_container_image=fake_image', |
| '--experiments=use_runner_v2', |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| 1, |
| len([ |
| x for x in env.proto.experiments |
| if x.startswith('runner_harness_container_image=') |
| ])) |
| self.assertIn( |
| 'runner_harness_container_image=fake_image', env.proto.experiments) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0.rc1') |
| def test_harness_override_uses_base_version_in_rc_releases(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--streaming']) |
| override = ''.join([ |
| 'runner_harness_container_image=', |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY, |
| '/harness:2.2.0' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertIn(override, env.proto.experiments) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0.dev') |
| def test_harness_override_absent_in_unreleased_sdk(self): |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--streaming']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| if env.proto.experiments: |
| for experiment in env.proto.experiments: |
| self.assertNotIn('runner_harness_container_image=', experiment) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0.dev') |
| def test_pinned_worker_harness_image_tag_used_in_dev_sdk(self): |
| # streaming, fnapi pipeline. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--streaming']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, |
| ( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d-fnapi:%s' % |
| ( |
| sys.version_info[0], |
| sys.version_info[1], |
| names.BEAM_FNAPI_CONTAINER_VERSION))) |
| |
| # batch, legacy pipeline. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, |
| ( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:%s' % ( |
| sys.version_info[0], |
| sys.version_info[1], |
| names.BEAM_CONTAINER_VERSION))) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_worker_harness_image_tag_matches_released_sdk_version(self): |
| # streaming, fnapi pipeline. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--streaming']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, |
| ( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + |
| '/python%d%d-fnapi:2.2.0' % |
| (sys.version_info[0], sys.version_info[1]))) |
| |
| # batch, legacy pipeline. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, |
| ( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:2.2.0' % |
| (sys.version_info[0], sys.version_info[1]))) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0.rc1') |
| def test_worker_harness_image_tag_matches_base_sdk_version_of_an_rc(self): |
| # streaming, fnapi pipeline. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp', '--streaming']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, |
| ( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + |
| '/python%d%d-fnapi:2.2.0' % |
| (sys.version_info[0], sys.version_info[1]))) |
| |
| # batch, legacy pipeline. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, |
| ( |
| names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:2.2.0' % |
| (sys.version_info[0], sys.version_info[1]))) |
| |
| def test_worker_harness_override_takes_precedence_over_sdk_defaults(self): |
| # streaming, fnapi pipeline. |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--streaming', |
| '--sdk_container_image=some:image' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, 'some:image') |
| # batch, legacy pipeline. |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://any-location/temp', |
| '--sdk_container_image=some:image' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.workerPools[0].workerHarnessContainerImage, 'some:image') |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.Job.' |
| 'job_id_for_name', |
| return_value='test_id') |
| def test_transform_name_mapping(self, mock_job): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--update', |
| '--transform_name_mapping', |
| '{\"from\":\"to\"}' |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| self.assertIsNotNone(job.proto.transformNameMapping) |
| |
| def test_created_from_snapshot_id(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--create_from_snapshot', |
| 'test_snapshot_id' |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| self.assertEqual('test_snapshot_id', job.proto.createdFromSnapshotId) |
| |
| def test_labels(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp' |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| self.assertIsNone(job.proto.labels) |
| |
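    # A label passed without '=value' maps its key to the empty string.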
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--label', |
| 'key1=value1', |
| '--label', |
| 'key2', |
| '--label', |
| 'key3=value3', |
| '--labels', |
| 'key4=value4', |
| '--labels', |
| 'key5' |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| self.assertEqual(5, len(job.proto.labels.additionalProperties)) |
| self.assertEqual('key1', job.proto.labels.additionalProperties[0].key) |
| self.assertEqual('value1', job.proto.labels.additionalProperties[0].value) |
| self.assertEqual('key2', job.proto.labels.additionalProperties[1].key) |
| self.assertEqual('', job.proto.labels.additionalProperties[1].value) |
| self.assertEqual('key3', job.proto.labels.additionalProperties[2].key) |
| self.assertEqual('value3', job.proto.labels.additionalProperties[2].value) |
| self.assertEqual('key4', job.proto.labels.additionalProperties[3].key) |
| self.assertEqual('value4', job.proto.labels.additionalProperties[3].value) |
| self.assertEqual('key5', job.proto.labels.additionalProperties[4].key) |
| self.assertEqual('', job.proto.labels.additionalProperties[4].value) |
| |
| def test_experiment_use_multiple_sdk_containers(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--experiments', |
| 'beam_fn_api' |
| ]) |
| environment = apiclient.Environment([], |
| pipeline_options, |
| 1, |
| FAKE_PIPELINE_URL) |
| self.assertIn('use_multiple_sdk_containers', environment.proto.experiments) |
| |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--experiments', |
| 'beam_fn_api', |
| '--experiments', |
| 'use_multiple_sdk_containers' |
| ]) |
| environment = apiclient.Environment([], |
| pipeline_options, |
| 1, |
| FAKE_PIPELINE_URL) |
| self.assertIn('use_multiple_sdk_containers', environment.proto.experiments) |
| |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--experiments', |
| 'beam_fn_api', |
| '--experiments', |
| 'no_use_multiple_sdk_containers' |
| ]) |
| environment = apiclient.Environment([], |
| pipeline_options, |
| 1, |
| FAKE_PIPELINE_URL) |
| self.assertNotIn( |
| 'use_multiple_sdk_containers', environment.proto.experiments) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', |
| (3, 8)) |
| def test_get_python_sdk_name(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--experiments', |
| 'beam_fn_api', |
| '--experiments', |
| 'use_multiple_sdk_containers' |
| ]) |
| environment = apiclient.Environment([], |
| pipeline_options, |
| 1, |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| 'Apache Beam Python 3.8 SDK', environment._get_python_sdk_name()) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', |
| (2, 7)) |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_interpreter_version_check_fails_py27(self): |
| pipeline_options = PipelineOptions([]) |
| self.assertRaises( |
| Exception, |
| apiclient._verify_interpreter_version_is_supported, |
| pipeline_options) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', |
| (3, 0, 0)) |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0.dev') |
| def test_interpreter_version_check_passes_on_dev_sdks(self): |
| pipeline_options = PipelineOptions([]) |
| apiclient._verify_interpreter_version_is_supported(pipeline_options) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', |
| (3, 0, 0)) |
| def test_interpreter_version_check_passes_with_experiment(self): |
| pipeline_options = PipelineOptions( |
| ["--experiment=use_unsupported_python_version"]) |
| apiclient._verify_interpreter_version_is_supported(pipeline_options) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', |
| (3, 8, 2)) |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_interpreter_version_check_passes_py38(self): |
| pipeline_options = PipelineOptions([]) |
| apiclient._verify_interpreter_version_is_supported(pipeline_options) |
| |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', |
| (3, 9, 0)) |
| @mock.patch( |
| 'apache_beam.runners.dataflow.internal.apiclient.' |
| 'beam_version.__version__', |
| '2.2.0') |
| def test_interpreter_version_check_fails_on_not_yet_supported_version(self): |
| pipeline_options = PipelineOptions([]) |
| self.assertRaises( |
| Exception, |
| apiclient._verify_interpreter_version_is_supported, |
| pipeline_options) |
| |
| def test_use_unified_worker(self): |
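    # The unified worker requires an explicit use_unified_worker or
    # use_runner_v2 experiment; beam_fn_api alone does not enable it.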
| pipeline_options = PipelineOptions([]) |
| self.assertFalse(apiclient._use_unified_worker(pipeline_options)) |
| |
| pipeline_options = PipelineOptions(['--experiments=beam_fn_api']) |
| self.assertFalse(apiclient._use_unified_worker(pipeline_options)) |
| |
| pipeline_options = PipelineOptions(['--experiments=use_unified_worker']) |
| self.assertTrue(apiclient._use_unified_worker(pipeline_options)) |
| |
| pipeline_options = PipelineOptions( |
| ['--experiments=use_unified_worker', '--experiments=beam_fn_api']) |
| self.assertTrue(apiclient._use_unified_worker(pipeline_options)) |
| |
| pipeline_options = PipelineOptions( |
| ['--experiments=use_runner_v2', '--experiments=beam_fn_api']) |
| self.assertTrue(apiclient._use_unified_worker(pipeline_options)) |
| |
| pipeline_options = PipelineOptions([ |
| '--experiments=use_unified_worker', |
| '--experiments=use_runner_v2', |
| '--experiments=beam_fn_api' |
| ]) |
| self.assertTrue(apiclient._use_unified_worker(pipeline_options)) |
| |
| def test_get_response_encoding(self): |
| encoding = apiclient.get_response_encoding() |
| |
    self.assertEqual('utf8', encoding)
| |
| def test_graph_is_uploaded(self): |
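    # With the upload_graph experiment, the job graph is staged as a separate
    # dataflow_graph.json file rather than embedded in the job description.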
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--experiments', |
| 'beam_fn_api', |
| '--experiments', |
| 'upload_graph' |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| pipeline_options.view_as(GoogleCloudOptions).no_auth = True |
| client = apiclient.DataflowApplicationClient(pipeline_options) |
| with mock.patch.object(client, 'stage_file', side_effect=None): |
| with mock.patch.object(client, 'create_job_description', |
| side_effect=None): |
| with mock.patch.object(client, |
| 'submit_job_description', |
| side_effect=None): |
| client.create_job(job) |
| client.stage_file.assert_called_once_with( |
| mock.ANY, "dataflow_graph.json", mock.ANY) |
| client.create_job_description.assert_called_once() |
| |
| def test_create_job_returns_existing_job(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| self.assertTrue(job.proto.clientRequestId) # asserts non-empty string |
| pipeline_options.view_as(GoogleCloudOptions).no_auth = True |
| client = apiclient.DataflowApplicationClient(pipeline_options) |
| |
| response = dataflow.Job() |
| # different clientRequestId from `job` |
| response.clientRequestId = "20210821081910123456-1234" |
| response.name = 'test_job_name' |
| response.id = '2021-08-19_21_18_43-9756917246311111021' |
| |
| with mock.patch.object(client._client.projects_locations_jobs, |
| 'Create', |
| side_effect=[response]): |
| with mock.patch.object(client, 'create_job_description', |
| side_effect=None): |
| with self.assertRaises( |
| apiclient.DataflowJobAlreadyExistsError) as context: |
| client.create_job(job) |
| |
| self.assertEqual( |
| str(context.exception), |
| 'There is already active job named %s with id: %s. If you want to ' |
| 'submit a second job, try again by setting a different name using ' |
| '--job_name.' % ('test_job_name', response.id)) |
| |
| def test_update_job_returns_existing_job(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--region', |
| 'us-central1', |
| '--update', |
| ]) |
| replace_job_id = '2021-08-21_00_00_01-6081497447916622336' |
| with mock.patch('apache_beam.runners.dataflow.internal.apiclient.Job.' |
| 'job_id_for_name', |
| return_value=replace_job_id) as job_id_for_name_mock: |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| job_id_for_name_mock.assert_called_once() |
| |
| self.assertTrue(job.proto.clientRequestId) # asserts non-empty string |
| |
| pipeline_options.view_as(GoogleCloudOptions).no_auth = True |
| client = apiclient.DataflowApplicationClient(pipeline_options) |
| |
| response = dataflow.Job() |
| # different clientRequestId from `job` |
| response.clientRequestId = "20210821083254123456-1234" |
| response.name = 'test_job_name' |
| response.id = '2021-08-19_21_29_07-5725551945600207770' |
| |
| with mock.patch.object(client, 'create_job_description', side_effect=None): |
| with mock.patch.object(client._client.projects_locations_jobs, |
| 'Create', |
| side_effect=[response]): |
| |
| with self.assertRaises( |
| apiclient.DataflowJobAlreadyExistsError) as context: |
| client.create_job(job) |
| |
| self.assertEqual( |
| str(context.exception), |
| 'The job named %s with id: %s has already been updated into job ' |
| 'id: %s and cannot be updated again.' % |
| ('test_job_name', replace_job_id, response.id)) |
| |
| def test_template_file_generation_with_upload_graph(self): |
| pipeline_options = PipelineOptions([ |
| '--project', |
| 'test_project', |
| '--job_name', |
| 'test_job_name', |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--experiments', |
| 'upload_graph', |
| '--template_location', |
| 'gs://test-location/template' |
| ]) |
| job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL) |
| job.proto.steps.append(dataflow.Step(name='test_step_name')) |
| |
| pipeline_options.view_as(GoogleCloudOptions).no_auth = True |
| client = apiclient.DataflowApplicationClient(pipeline_options) |
| with mock.patch.object(client, 'stage_file', side_effect=None): |
| with mock.patch.object(client, 'create_job_description', |
| side_effect=None): |
| with mock.patch.object(client, |
| 'submit_job_description', |
| side_effect=None): |
| client.create_job(job) |
| |
| client.stage_file.assert_has_calls([ |
| mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY), |
| mock.call(mock.ANY, 'template', mock.ANY) |
| ]) |
| client.create_job_description.assert_called_once() |
| # template is generated, but job should not be submitted to the |
| # service. |
| client.submit_job_description.assert_not_called() |
| |
| template_filename = client.stage_file.call_args_list[-1][0][1] |
          self.assertIn('template', template_filename)
| template_content = client.stage_file.call_args_list[-1][0][2].read( |
| ).decode('utf-8') |
| template_obj = json.loads(template_content) |
| self.assertFalse(template_obj.get('steps')) |
| self.assertTrue(template_obj['stepsLocation']) |
| |
| def test_stage_resources(self): |
| pipeline_options = PipelineOptions([ |
| '--temp_location', |
| 'gs://test-location/temp', |
| '--staging_location', |
| 'gs://test-location/staging', |
| '--no_auth' |
| ]) |
| pipeline = beam_runner_api_pb2.Pipeline( |
| components=beam_runner_api_pb2.Components( |
| environments={ |
| 'env1': beam_runner_api_pb2.Environment( |
| dependencies=[ |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.FILE.urn, |
| type_payload=beam_runner_api_pb2. |
| ArtifactFilePayload( |
| path='/tmp/foo1').SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='foo1').SerializeToString()), |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.FILE.urn, |
| type_payload=beam_runner_api_pb2. |
| ArtifactFilePayload( |
| path='/tmp/bar1').SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='bar1').SerializeToString()) |
| ]), |
| 'env2': beam_runner_api_pb2.Environment( |
| dependencies=[ |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.FILE.urn, |
| type_payload=beam_runner_api_pb2. |
| ArtifactFilePayload( |
| path='/tmp/foo2').SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='foo2').SerializeToString()), |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.FILE.urn, |
| type_payload=beam_runner_api_pb2. |
| ArtifactFilePayload( |
| path='/tmp/bar2').SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='bar2').SerializeToString()) |
| ]) |
| })) |
| client = apiclient.DataflowApplicationClient(pipeline_options) |
| with mock.patch.object(apiclient._LegacyDataflowStager, |
| 'stage_job_resources') as mock_stager: |
| client._stage_resources(pipeline, pipeline_options) |
| mock_stager.assert_called_once_with( |
| [('/tmp/foo1', 'foo1'), ('/tmp/bar1', 'bar1'), ('/tmp/foo2', 'foo2'), |
| ('/tmp/bar2', 'bar2')], |
| staging_location='gs://test-location/staging') |
| |
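    # Staging should rewrite each FILE artifact into a URL artifact that
    # points at its staged location.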
| pipeline_expected = beam_runner_api_pb2.Pipeline( |
| components=beam_runner_api_pb2.Components( |
| environments={ |
| 'env1': beam_runner_api_pb2.Environment( |
| dependencies=[ |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.URL.urn, |
| type_payload=beam_runner_api_pb2.ArtifactUrlPayload( |
| url='gs://test-location/staging/foo1' |
| ).SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='foo1').SerializeToString()), |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.URL.urn, |
| type_payload=beam_runner_api_pb2.ArtifactUrlPayload( |
| url='gs://test-location/staging/bar1'). |
| SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='bar1').SerializeToString()) |
| ]), |
| 'env2': beam_runner_api_pb2.Environment( |
| dependencies=[ |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.URL.urn, |
| type_payload=beam_runner_api_pb2.ArtifactUrlPayload( |
| url='gs://test-location/staging/foo2'). |
| SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='foo2').SerializeToString()), |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.URL.urn, |
| type_payload=beam_runner_api_pb2.ArtifactUrlPayload( |
| url='gs://test-location/staging/bar2'). |
| SerializeToString(), |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2. |
| ArtifactStagingToRolePayload( |
| staged_name='bar2').SerializeToString()) |
| ]) |
| })) |
| self.assertEqual(pipeline, pipeline_expected) |
| |
| def test_set_dataflow_service_option(self): |
| pipeline_options = PipelineOptions([ |
| '--dataflow_service_option', |
| 'whizz=bang', |
| '--temp_location', |
| 'gs://any-location/temp' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual(env.proto.serviceOptions, ['whizz=bang']) |
| |
| def test_enable_hot_key_logging(self): |
| # Tests that the enable_hot_key_logging is not set by default. |
| pipeline_options = PipelineOptions( |
| ['--temp_location', 'gs://any-location/temp']) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertIsNone(env.proto.debugOptions) |
| |
| # Now test that it is set when given. |
| pipeline_options = PipelineOptions([ |
| '--enable_hot_key_logging', '--temp_location', 'gs://any-location/temp' |
| ]) |
| env = apiclient.Environment( |
| [], #packages |
| pipeline_options, |
| '2.0.0', #any environment version |
| FAKE_PIPELINE_URL) |
| self.assertEqual( |
| env.proto.debugOptions, dataflow.DebugOptions(enableHotKeyLogging=True)) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |