| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Unit tests for the DataflowRunner class.""" |
| # pytype: skip-file |
| |
| import unittest |
| |
| import mock |
| |
| import apache_beam as beam |
| import apache_beam.transforms as ptransform |
| from apache_beam.options.pipeline_options import DebugOptions |
| from apache_beam.options.pipeline_options import GoogleCloudOptions |
| from apache_beam.options.pipeline_options import PipelineOptions |
| from apache_beam.pipeline import AppliedPTransform |
| from apache_beam.pipeline import Pipeline |
| from apache_beam.portability import common_urns |
| from apache_beam.portability import python_urns |
| from apache_beam.portability.api import beam_runner_api_pb2 |
| from apache_beam.pvalue import PCollection |
| from apache_beam.runners import DataflowRunner |
| from apache_beam.runners import TestDataflowRunner |
| from apache_beam.runners import create_runner |
| from apache_beam.runners import pipeline_utils |
| from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult |
| from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException |
| from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_options |
| from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_streaming_options |
| from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api |
| from apache_beam.runners.internal import names |
| from apache_beam.runners.runner import PipelineState |
| from apache_beam.testing.extra_assertions import ExtraAssertionsMixin |
| from apache_beam.testing.test_pipeline import TestPipeline |
| from apache_beam.transforms import combiners |
| from apache_beam.transforms import environments |
| from apache_beam.typehints import typehints |
| |
| # Protect against environments where apitools library is not available. |
| # pylint: disable=wrong-import-order, wrong-import-position |
| try: |
| from apache_beam.runners.dataflow.internal import apiclient |
| except ImportError: |
| apiclient = None # type: ignore |
| # pylint: enable=wrong-import-order, wrong-import-position |
| |
| |
| # SpecialParDo and SpecialDoFn are used in test_remote_runner_display_data. |
| # Due to https://github.com/apache/beam/issues/19848, these need to be declared |
| # outside of the test method. |
| # TODO: Should not subclass ParDo. Switch to PTransform as soon as |
| # composite transforms support display data. |
| class SpecialParDo(beam.ParDo): |
| def __init__(self, fn, now): |
| super().__init__(fn) |
| self.fn = fn |
| self.now = now |
| |
| # Expose the wrapped DoFn, this class, and the construction time via |
| # display data. |
| def display_data(self): |
| return { |
| 'asubcomponent': self.fn, 'a_class': SpecialParDo, 'a_time': self.now |
| } |
| |
| |
| class SpecialDoFn(beam.DoFn): |
| def display_data(self): |
| return {'dofn_value': 42} |
| |
| def process(self): |
| pass |
| |
| |
| @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') |
| class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin): |
| def setUp(self): |
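| # Common flags for graph-construction-only runs: --dry_run=True keeps the |
| # runner from submitting a real Dataflow job, so these tests only exercise |
| # pipeline translation and job-spec construction. |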
| self.default_properties = [ |
| '--job_name=test-job', |
| '--project=test-project', |
| '--region=us-central1', |
| '--staging_location=gs://beam/test', |
| '--temp_location=gs://beam/tmp', |
| '--no_auth', |
| '--dry_run=True', |
| '--sdk_location=container' |
| ] |
| |
| @mock.patch('time.sleep', return_value=None) |
| def test_wait_until_finish_unrecognized(self, patched_time_sleep): |
| values_enum = dataflow_api.Job.CurrentStateValueValuesEnum |
| options = PipelineOptions(self.default_properties) |
| |
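| # Minimal stand-in for the runner: get_job() advances through the supplied |
| # states (sticking on the last one) and list_messages() returns nothing. |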
| class MockDataflowRunner(object): |
| def __init__(self, states): |
| self.dataflow_client = mock.MagicMock() |
| self.job = mock.MagicMock() |
| self.job.id = "test-job-id" |
| self.job.currentState = values_enum.JOB_STATE_UNKNOWN |
| self._states = states |
| self._next_state_index = 0 |
| |
| def get_job_side_effect(*args, **kwargs): |
| self.job.currentState = self._states[self._next_state_index] |
| if self._next_state_index < (len(self._states) - 1): |
| self._next_state_index += 1 |
| return mock.DEFAULT |
| |
| self.dataflow_client.get_job = mock.MagicMock( |
| return_value=self.job, side_effect=get_job_side_effect) |
| self.dataflow_client.list_messages = mock.MagicMock( |
| return_value=([], None)) |
| |
| with self.assertRaisesRegex( |
| AssertionError, |
| r'Job did not reach to a terminal state after waiting indefinitely. ' |
| r'Console URL: ' |
| r'https://console.cloud.google.com/dataflow/jobs/' |
| r'us-central1/test-job-id\?project=test-project'): |
| failed_runner = MockDataflowRunner(["some_unrecognized_state"]) |
| failed_result = DataflowPipelineResult( |
| failed_runner.job, failed_runner, options) |
| failed_result.wait_until_finish() |
| |
| @mock.patch('time.sleep', return_value=None) |
| def test_wait_until_finish(self, patched_time_sleep): |
| values_enum = dataflow_api.Job.CurrentStateValueValuesEnum |
| options = PipelineOptions(self.default_properties) |
| |
| class MockDataflowRunner(object): |
| def __init__(self, states): |
| self.dataflow_client = mock.MagicMock() |
| self.job = mock.MagicMock() |
| self.job.currentState = values_enum.JOB_STATE_UNKNOWN |
| self._states = states |
| self._next_state_index = 0 |
| |
| def get_job_side_effect(*args, **kwargs): |
| self.job.currentState = self._states[self._next_state_index] |
| if self._next_state_index < (len(self._states) - 1): |
| self._next_state_index += 1 |
| return mock.DEFAULT |
| |
| self.dataflow_client.get_job = mock.MagicMock( |
| return_value=self.job, side_effect=get_job_side_effect) |
| self.dataflow_client.list_messages = mock.MagicMock( |
| return_value=([], None)) |
| |
| with self.assertRaisesRegex(DataflowRuntimeException, |
| 'Dataflow pipeline failed. State: FAILED'): |
| failed_runner = MockDataflowRunner([values_enum.JOB_STATE_FAILED]) |
| failed_result = DataflowPipelineResult( |
| failed_runner.job, failed_runner, options) |
| failed_result.wait_until_finish() |
| |
| # Check that a second call still triggers the exception. |
| with self.assertRaisesRegex(DataflowRuntimeException, |
| 'Dataflow pipeline failed. State: FAILED'): |
| failed_result.wait_until_finish() |
| |
| succeeded_runner = MockDataflowRunner([values_enum.JOB_STATE_DONE]) |
| succeeded_result = DataflowPipelineResult( |
| succeeded_runner.job, succeeded_runner, options) |
| result = succeeded_result.wait_until_finish() |
| self.assertEqual(result, PipelineState.DONE) |
| |
| # The mocked time values contain duplicates because some logging |
| # implementations also call time.time(). |
| with mock.patch('time.time', mock.MagicMock(side_effect=[1, 1, 2, 2, 3])): |
| duration_succeeded_runner = MockDataflowRunner( |
| [values_enum.JOB_STATE_RUNNING, values_enum.JOB_STATE_DONE]) |
| duration_succeeded_result = DataflowPipelineResult( |
| duration_succeeded_runner.job, duration_succeeded_runner, options) |
| result = duration_succeeded_result.wait_until_finish(5000) |
| self.assertEqual(result, PipelineState.DONE) |
| |
| with mock.patch('time.time', mock.MagicMock(side_effect=[1, 9, 9, 20, 20])): |
| duration_timedout_runner = MockDataflowRunner( |
| [values_enum.JOB_STATE_RUNNING]) |
| duration_timedout_result = DataflowPipelineResult( |
| duration_timedout_runner.job, duration_timedout_runner, options) |
| result = duration_timedout_result.wait_until_finish(5000) |
| self.assertEqual(result, PipelineState.RUNNING) |
| |
| with mock.patch('time.time', mock.MagicMock(side_effect=[1, 1, 2, 2, 3])): |
| with self.assertRaisesRegex(DataflowRuntimeException, |
| 'Dataflow pipeline failed. State: CANCELLED'): |
| duration_failed_runner = MockDataflowRunner( |
| [values_enum.JOB_STATE_CANCELLED]) |
| duration_failed_result = DataflowPipelineResult( |
| duration_failed_runner.job, duration_failed_runner, options) |
| duration_failed_result.wait_until_finish(5000) |
| |
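| # cancel() should raise when the service rejects the request for a running |
| # job, succeed when the request is accepted, and not raise for a job that |
| # is already in a terminal state. |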
| @mock.patch('time.sleep', return_value=None) |
| def test_cancel(self, patched_time_sleep): |
| values_enum = dataflow_api.Job.CurrentStateValueValuesEnum |
| options = PipelineOptions( |
| self.default_properties).view_as(GoogleCloudOptions) |
| |
| class MockDataflowRunner(object): |
| def __init__(self, state, cancel_result): |
| self.dataflow_client = mock.MagicMock() |
| self.job = mock.MagicMock() |
| self.job.currentState = state |
| |
| self.dataflow_client.get_job = mock.MagicMock(return_value=self.job) |
| self.dataflow_client.modify_job_state = mock.MagicMock( |
| return_value=cancel_result) |
| self.dataflow_client.list_messages = mock.MagicMock( |
| return_value=([], None)) |
| |
| with self.assertRaisesRegex(DataflowRuntimeException, |
| 'Failed to cancel job'): |
| failed_runner = MockDataflowRunner(values_enum.JOB_STATE_RUNNING, False) |
| failed_result = DataflowPipelineResult( |
| failed_runner.job, failed_runner, options) |
| failed_result.cancel() |
| |
| succeeded_runner = MockDataflowRunner(values_enum.JOB_STATE_RUNNING, True) |
| succeeded_result = DataflowPipelineResult( |
| succeeded_runner.job, succeeded_runner, options) |
| succeeded_result.cancel() |
| |
| terminal_runner = MockDataflowRunner(values_enum.JOB_STATE_DONE, False) |
| terminal_result = DataflowPipelineResult( |
| terminal_runner.job, terminal_runner, options) |
| terminal_result.cancel() |
| |
| def test_create_runner(self): |
| self.assertIsInstance(create_runner('DataflowRunner'), DataflowRunner) |
| self.assertIsInstance( |
| create_runner('TestDataflowRunner'), TestDataflowRunner) |
| |
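| # Builds the ArtifactInformation entry expected for the pickled main session |
| # staged from the given local path. |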
| @staticmethod |
| def dependency_proto_from_main_session_file(serialized_path): |
| return [ |
| beam_runner_api_pb2.ArtifactInformation( |
| type_urn=common_urns.artifact_types.FILE.urn, |
| type_payload=serialized_path, |
| role_urn=common_urns.artifact_roles.STAGING_TO.urn, |
| role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload( |
| staged_name=names.PICKLED_MAIN_SESSION_FILE).SerializeToString( |
| )) |
| ] |
| |
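| # The two environment-override tests below verify that the container image |
| # flags (--worker_harness_container_image and --sdk_container_image) become |
| # the Docker image of the pipeline's Environment proto. |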
| def test_environment_override_translation_legacy_worker_harness_image(self): |
| self.default_properties.append('--experiments=beam_fn_api') |
| self.default_properties.append('--worker_harness_container_image=LEGACY') |
| remote_runner = DataflowRunner() |
| options = PipelineOptions(self.default_properties) |
| options.view_as(DebugOptions).add_experiment( |
| 'disable_logging_submission_environment') |
| with Pipeline(remote_runner, options=options) as p: |
| ( # pylint: disable=expression-not-assigned |
| p | ptransform.Create([1, 2, 3]) |
| | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)]) |
| | ptransform.GroupByKey()) |
| |
| actual = list(remote_runner.proto_pipeline.components.environments.values()) |
| self.assertEqual(len(actual), 1) |
| actual = actual[0] |
| file_path = actual.dependencies[0].type_payload |
| # The dependency payload points at the main session file pickled into a |
| # transient temp directory, so reuse the actual path in the expected value. |
| main_session_dep = self.dependency_proto_from_main_session_file(file_path) |
| self.assertEqual( |
| actual, |
| beam_runner_api_pb2.Environment( |
| urn=common_urns.environments.DOCKER.urn, |
| payload=beam_runner_api_pb2.DockerPayload( |
| container_image='LEGACY').SerializeToString(), |
| capabilities=environments.python_sdk_docker_capabilities(), |
| dependencies=environments.python_sdk_dependencies(options=options) + |
| main_session_dep)) |
| |
| def test_environment_override_translation_sdk_container_image(self): |
| self.default_properties.append('--experiments=beam_fn_api') |
| self.default_properties.append('--sdk_container_image=FOO') |
| remote_runner = DataflowRunner() |
| options = PipelineOptions(self.default_properties) |
| options.view_as(DebugOptions).add_experiment( |
| 'disable_logging_submission_environment') |
| with Pipeline(remote_runner, options=options) as p: |
| ( # pylint: disable=expression-not-assigned |
| p | ptransform.Create([1, 2, 3]) |
| | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)]) |
| | ptransform.GroupByKey()) |
| |
| actual = list(remote_runner.proto_pipeline.components.environments.values()) |
| self.assertEqual(len(actual), 1) |
| actual = actual[0] |
| file_path = actual.dependencies[0].type_payload |
| # The dependency payload points at the main session file pickled into a |
| # transient temp directory, so reuse the actual path in the expected value. |
| main_session_dep = self.dependency_proto_from_main_session_file(file_path) |
| self.assertEqual( |
| actual, |
| beam_runner_api_pb2.Environment( |
| urn=common_urns.environments.DOCKER.urn, |
| payload=beam_runner_api_pb2.DockerPayload( |
| container_image='FOO').SerializeToString(), |
| capabilities=environments.python_sdk_docker_capabilities(), |
| dependencies=environments.python_sdk_dependencies(options=options) + |
| main_session_dep)) |
| |
| def test_remote_runner_translation(self): |
| remote_runner = DataflowRunner() |
| with Pipeline(remote_runner, |
| options=PipelineOptions(self.default_properties)) as p: |
| |
| ( # pylint: disable=expression-not-assigned |
| p | ptransform.Create([1, 2, 3]) |
| | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)]) |
| | ptransform.GroupByKey()) |
| |
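| # The GroupByKey input visitor should coerce under-specified input element |
| # types (None or Any) to KV[Any, Any] and reject element types that are not |
| # compatible with KV[Any, Any]. |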
| def test_group_by_key_input_visitor_with_valid_inputs(self): |
| p = TestPipeline() |
| pcoll1 = PCollection(p) |
| pcoll2 = PCollection(p) |
| pcoll3 = PCollection(p) |
| |
| pcoll1.element_type = None |
| pcoll2.element_type = typehints.Any |
| pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any] |
| for pcoll in [pcoll1, pcoll2, pcoll3]: |
| applied = AppliedPTransform( |
| None, beam.GroupByKey(), "label", {'pcoll': pcoll}, None, None) |
| applied.outputs[None] = PCollection(None) |
| pipeline_utils.group_by_key_input_visitor().visit_transform(applied) |
| self.assertEqual( |
| pcoll.element_type, typehints.KV[typehints.Any, typehints.Any]) |
| |
| def test_group_by_key_input_visitor_with_invalid_inputs(self): |
| p = TestPipeline() |
| pcoll1 = PCollection(p) |
| pcoll2 = PCollection(p) |
| |
| pcoll1.element_type = str |
| pcoll2.element_type = typehints.Set |
| err_msg = ( |
| r"Input to 'label' must be compatible with KV\[Any, Any\]. " |
| "Found .*") |
| for pcoll in [pcoll1, pcoll2]: |
| with self.assertRaisesRegex(ValueError, err_msg): |
| pipeline_utils.group_by_key_input_visitor().visit_transform( |
| AppliedPTransform( |
| None, beam.GroupByKey(), "label", {'in': pcoll}, None, None)) |
| |
| def test_group_by_key_input_visitor_for_non_gbk_transforms(self): |
| p = TestPipeline() |
| pcoll = PCollection(p) |
| for transform in [beam.Flatten(), beam.Map(lambda x: x)]: |
| pcoll.element_type = typehints.Any |
| pipeline_utils.group_by_key_input_visitor().visit_transform( |
| AppliedPTransform( |
| None, transform, "label", {'in': pcoll}, None, None)) |
| self.assertEqual(pcoll.element_type, typehints.Any) |
| |
| def test_flatten_input_with_visitor_with_single_input(self): |
| self._test_flatten_input_visitor(typehints.KV[int, int], typehints.Any, 1) |
| |
| def test_flatten_input_with_visitor_with_multiple_inputs(self): |
| self._test_flatten_input_visitor( |
| typehints.KV[int, typehints.Any], typehints.Any, 5) |
| |
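| # Builds a Flatten with num_inputs inputs of the given element type and |
| # checks that the flatten input visitor rewrites every input's element type |
| # to match the output's element type. |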
| def _test_flatten_input_visitor(self, input_type, output_type, num_inputs): |
| p = TestPipeline() |
| inputs = {} |
| for ix in range(num_inputs): |
| input_pcoll = PCollection(p) |
| input_pcoll.element_type = input_type |
| inputs[str(ix)] = input_pcoll |
| output_pcoll = PCollection(p) |
| output_pcoll.element_type = output_type |
| |
| flatten = AppliedPTransform( |
| None, beam.Flatten(), "label", inputs, None, None) |
| flatten.add_output(output_pcoll, None) |
| DataflowRunner.flatten_input_visitor().visit_transform(flatten) |
| for ix in range(num_inputs): |
| self.assertEqual(inputs[str(ix)].element_type, output_type) |
| |
| def test_gbk_then_flatten_input_visitor(self): |
| p = TestPipeline( |
| runner=DataflowRunner(), |
| options=PipelineOptions(self.default_properties)) |
| none_str_pc = p | 'c1' >> beam.Create({None: 'a'}) |
| none_int_pc = p | 'c2' >> beam.Create({None: 3}) |
| flat = (none_str_pc, none_int_pc) | beam.Flatten() |
| _ = flat | beam.GroupByKey() |
| |
| # This may change if type inference changes, but we assert it here |
| # to make sure the check below is not vacuous. |
| self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint) |
| |
| p.visit(pipeline_utils.group_by_key_input_visitor()) |
| p.visit(DataflowRunner.flatten_input_visitor()) |
| |
| # The Dataflow runner requires GBK input element types to be tuples *and* |
| # flatten input element types to equal their output's. Assert both hold. |
| self.assertIsInstance(flat.element_type, typehints.TupleConstraint) |
| self.assertEqual(flat.element_type, none_str_pc.element_type) |
| self.assertEqual(flat.element_type, none_int_pc.element_type) |
| |
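| # The side input visitor should rewrite AsSingleton to the iterable access |
| # pattern while AsMultiMap keeps the multimap access pattern. |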
| def test_side_input_visitor(self): |
| p = TestPipeline() |
| pc = p | beam.Create([]) |
| |
| transform = beam.Map( |
| lambda x, y, z: (x, y, z), |
| beam.pvalue.AsSingleton(pc), |
| beam.pvalue.AsMultiMap(pc)) |
| applied_transform = AppliedPTransform( |
| None, transform, "label", {'pc': pc}, None, None) |
| DataflowRunner.side_input_visitor().visit_transform(applied_transform) |
| self.assertEqual(2, len(applied_transform.side_inputs)) |
| self.assertEqual( |
| common_urns.side_inputs.ITERABLE.urn, |
| applied_transform.side_inputs[0]._side_input_data().access_pattern) |
| self.assertEqual( |
| common_urns.side_inputs.MULTIMAP.urn, |
| applied_transform.side_inputs[1]._side_input_data().access_pattern) |
| |
| def test_min_cpu_platform_flag_is_propagated_to_experiments(self): |
| remote_runner = DataflowRunner() |
| self.default_properties.append('--min_cpu_platform=Intel Haswell') |
| |
| with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p: |
| p | ptransform.Create([1]) # pylint: disable=expression-not-assigned |
| self.assertIn( |
| 'min_cpu_platform=Intel Haswell', |
| remote_runner.job.options.view_as(DebugOptions).experiments) |
| |
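| # Enabling --streaming should add the Windmill / Streaming Engine |
| # experiments automatically while preserving user-supplied experiments. |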
| def test_streaming_adds_windmill_experiments(self): |
| remote_runner = DataflowRunner() |
| self.default_properties.append('--streaming') |
| self.default_properties.append('--experiment=some_other_experiment') |
| |
| with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p: |
| p | ptransform.Create([1]) # pylint: disable=expression-not-assigned |
| |
| experiments_for_job = ( |
| remote_runner.job.options.view_as(DebugOptions).experiments) |
| self.assertIn('enable_streaming_engine', experiments_for_job) |
| self.assertIn('enable_windmill_service', experiments_for_job) |
| self.assertIn('some_other_experiment', experiments_for_job) |
| |
| def test_upload_graph_experiment(self): |
| remote_runner = DataflowRunner() |
| self.default_properties.append('--experiment=upload_graph') |
| |
| with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p: |
| p | ptransform.Create([1]) # pylint: disable=expression-not-assigned |
| |
| experiments_for_job = ( |
| remote_runner.job.options.view_as(DebugOptions).experiments) |
| self.assertIn('upload_graph', experiments_for_job) |
| |
| def test_use_fastavro_experiment_is_not_added_when_use_avro_is_present(self): |
| remote_runner = DataflowRunner() |
| self.default_properties.append('--experiment=use_avro') |
| |
| with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p: |
| p | ptransform.Create([1]) # pylint: disable=expression-not-assigned |
| |
| debug_options = remote_runner.job.options.view_as(DebugOptions) |
| |
| self.assertFalse(debug_options.lookup_experiment('use_fastavro', False)) |
| |
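| # get_default_gcp_region() falls back from an environment variable to the |
| # gcloud CLI and returns None (rather than raising) when neither yields a |
| # region; both lookups are mocked in the tests below. |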
| @mock.patch('os.environ.get', return_value=None) |
| @mock.patch('apache_beam.utils.processes.check_output', return_value=b'') |
| def test_get_default_gcp_region_no_default_returns_none( |
| self, patched_environ, patched_processes): |
| runner = DataflowRunner() |
| result = runner.get_default_gcp_region() |
| self.assertIsNone(result) |
| |
| @mock.patch('os.environ.get', return_value='some-region1') |
| @mock.patch('apache_beam.utils.processes.check_output', return_value=b'') |
| def test_get_default_gcp_region_from_environ( |
| self, patched_environ, patched_processes): |
| runner = DataflowRunner() |
| result = runner.get_default_gcp_region() |
| self.assertEqual(result, 'some-region1') |
| |
| @mock.patch('os.environ.get', return_value=None) |
| @mock.patch( |
| 'apache_beam.utils.processes.check_output', |
| return_value=b'some-region2\n') |
| def test_get_default_gcp_region_from_gcloud( |
| self, patched_environ, patched_processes): |
| runner = DataflowRunner() |
| result = runner.get_default_gcp_region() |
| self.assertEqual(result, 'some-region2') |
| |
| @mock.patch('os.environ.get', return_value=None) |
| @mock.patch( |
| 'apache_beam.utils.processes.check_output', |
| side_effect=RuntimeError('Executable gcloud not found')) |
| def test_get_default_gcp_region_ignores_error( |
| self, patched_environ, patched_processes): |
| runner = DataflowRunner() |
| result = runner.get_default_gcp_region() |
| self.assertIsNone(result) |
| |
| @unittest.skip( |
| 'https://github.com/apache/beam/issues/18716: enable once ' |
| 'CombineFnVisitor is fixed') |
| def test_unsupported_combinefn_detection(self): |
| class CombinerWithNonDefaultSetupTeardown(combiners.CountCombineFn): |
| def setup(self, *args, **kwargs): |
| pass |
| |
| def teardown(self, *args, **kwargs): |
| pass |
| |
| runner = DataflowRunner() |
| with self.assertRaisesRegex(ValueError, |
| 'CombineFn.setup and CombineFn.' |
| 'teardown are not supported'): |
| with beam.Pipeline(runner=runner, |
| options=PipelineOptions(self.default_properties)) as p: |
| _ = ( |
| p | beam.Create([1]) |
| | beam.CombineGlobally(CombinerWithNonDefaultSetupTeardown())) |
| |
| try: |
| with beam.Pipeline(runner=runner, |
| options=PipelineOptions(self.default_properties)) as p: |
| _ = ( |
| p | beam.Create([1]) |
| | beam.CombineGlobally( |
| combiners.SingleInputTupleCombineFn( |
| combiners.CountCombineFn(), combiners.CountCombineFn()))) |
| except ValueError: |
| self.fail('ValueError raised unexpectedly') |
| |
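| # The APPLY_COMBINER_PACKING annotation opts this composite into combiner |
| # packing, so the two CombineGlobally transforms should be replaced by a |
| # single packed step in the submitted pipeline proto. |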
| def test_pack_combiners(self): |
| class PackableCombines(beam.PTransform): |
| def annotations(self): |
| return {python_urns.APPLY_COMBINER_PACKING: b''} |
| |
| def expand(self, pcoll): |
| _ = pcoll | 'PackableMin' >> beam.CombineGlobally(min) |
| _ = pcoll | 'PackableMax' >> beam.CombineGlobally(max) |
| |
| runner = DataflowRunner() |
| with beam.Pipeline(runner=runner, |
| options=PipelineOptions(self.default_properties)) as p: |
| _ = p | beam.Create([10, 20, 30]) | PackableCombines() |
| |
| unpacked_minimum_step_name = ( |
| 'PackableCombines/PackableMin/CombinePerKey/Combine') |
| unpacked_maximum_step_name = ( |
| 'PackableCombines/PackableMax/CombinePerKey/Combine') |
| packed_step_name = ( |
| 'PackableCombines/Packed[PackableMin_CombinePerKey, ' |
| 'PackableMax_CombinePerKey]/Pack') |
| transform_names = set( |
| transform.unique_name |
| for transform in runner.proto_pipeline.components.transforms.values()) |
| self.assertNotIn(unpacked_minimum_step_name, transform_names) |
| self.assertNotIn(unpacked_maximum_step_name, transform_names) |
| self.assertIn(packed_step_name, transform_names) |
| |
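| # _check_and_add_missing_options should enable the Runner v2 experiments by |
| # default; streaming jobs additionally need the Streaming Engine experiments |
| # added by _check_and_add_missing_streaming_options. |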
| def test_batch_is_runner_v2(self): |
| options = PipelineOptions(['--sdk_location=container']) |
| _check_and_add_missing_options(options) |
| for expected in ['beam_fn_api', |
| 'use_unified_worker', |
| 'use_runner_v2', |
| 'use_portable_job_submission']: |
| self.assertTrue( |
| options.view_as(DebugOptions).lookup_experiment(expected, False), |
| expected) |
| |
| def test_streaming_is_runner_v2(self): |
| options = PipelineOptions(['--sdk_location=container', '--streaming']) |
| _check_and_add_missing_options(options) |
| _check_and_add_missing_streaming_options(options) |
| for expected in ['beam_fn_api', |
| 'use_unified_worker', |
| 'use_runner_v2', |
| 'use_portable_job_submission', |
| 'enable_windmill_service', |
| 'enable_streaming_engine']: |
| self.assertTrue( |
| options.view_as(DebugOptions).lookup_experiment(expected, False), |
| expected) |
| |
| def test_dataflow_service_options_enable_prime_sets_runner_v2(self): |
| options = PipelineOptions([ |
| '--sdk_location=container', |
| '--streaming', |
| '--dataflow_service_options=enable_prime' |
| ]) |
| _check_and_add_missing_options(options) |
| for expected in ['beam_fn_api', |
| 'use_unified_worker', |
| 'use_runner_v2', |
| 'use_portable_job_submission']: |
| self.assertTrue( |
| options.view_as(DebugOptions).lookup_experiment(expected, False), |
| expected) |
| |
| options = PipelineOptions([ |
| '--sdk_location=container', |
| '--streaming', |
| '--dataflow_service_options=enable_prime' |
| ]) |
| _check_and_add_missing_options(options) |
| _check_and_add_missing_streaming_options(options) |
| for expected in ['beam_fn_api', |
| 'use_unified_worker', |
| 'use_runner_v2', |
| 'use_portable_job_submission', |
| 'enable_windmill_service', |
| 'enable_streaming_engine']: |
| self.assertTrue( |
| options.view_as(DebugOptions).lookup_experiment(expected, False), |
| expected) |
| |
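| # Without an explicit --streaming flag the job type should be inferred from |
| # the pipeline: an unbounded source (Pub/Sub) yields a streaming job, a |
| # bounded-only pipeline yields a batch job, and --streaming forces streaming. |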
| @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') |
| @mock.patch( |
| 'apache_beam.options.pipeline_options.GoogleCloudOptions.validate', |
| lambda *args: []) |
| def test_auto_streaming_with_unbounded(self): |
| options = PipelineOptions([ |
| '--sdk_location=container', |
| '--runner=DataflowRunner', |
| '--dry_run=True', |
| '--temp_location=gs://bucket', |
| '--project=project', |
| '--region=region' |
| ]) |
| with beam.Pipeline(options=options) as p: |
| _ = p | beam.io.ReadFromPubSub('projects/some-project/topics/some-topic') |
| self.assertEqual( |
| p.result.job.proto.type, |
| apiclient.dataflow.Job.TypeValueValuesEnum.JOB_TYPE_STREAMING) |
| |
| @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') |
| @mock.patch( |
| 'apache_beam.options.pipeline_options.GoogleCloudOptions.validate', |
| lambda *args: []) |
| def test_auto_streaming_no_unbounded(self): |
| options = PipelineOptions([ |
| '--sdk_location=container', |
| '--runner=DataflowRunner', |
| '--dry_run=True', |
| '--temp_location=gs://bucket', |
| '--project=project', |
| '--region=region' |
| ]) |
| with beam.Pipeline(options=options) as p: |
| _ = p | beam.Create([1, 2, 3]) |
| self.assertEqual( |
| p.result.job.proto.type, |
| apiclient.dataflow.Job.TypeValueValuesEnum.JOB_TYPE_BATCH) |
| |
| @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed') |
| @mock.patch( |
| 'apache_beam.options.pipeline_options.GoogleCloudOptions.validate', |
| lambda *args: []) |
| def test_explicit_streaming_no_unbounded(self): |
| options = PipelineOptions([ |
| '--streaming', |
| '--sdk_location=container', |
| '--runner=DataflowRunner', |
| '--dry_run=True', |
| '--temp_location=gs://bucket', |
| '--project=project', |
| '--region=region' |
| ]) |
| with beam.Pipeline(options=options) as p: |
| _ = p | beam.Create([1, 2, 3]) |
| self.assertEqual( |
| p.result.job.proto.type, |
| apiclient.dataflow.Job.TypeValueValuesEnum.JOB_TYPE_STREAMING) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |