blob: 21163fc121c50882b2109c0aebe5095d435dfada [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for apache_beam.runners.interactive.interactive_beam."""
# pytype: skip-file
import dataclasses
import importlib
import sys
import time
import unittest
from concurrent.futures import TimeoutError
from typing import NamedTuple
from unittest.mock import ANY
from unittest.mock import MagicMock
from unittest.mock import call
from unittest.mock import patch
import apache_beam as beam
from apache_beam import dataframe as frames
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.options.pipeline_options import FlinkRunnerOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import interactive_runner as ir
from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import DataprocClusterManager
from apache_beam.runners.interactive.dataproc.types import ClusterMetadata
from apache_beam.runners.interactive.options.capture_limiters import Limiter
from apache_beam.runners.interactive.recording_manager import AsyncComputationResult
from apache_beam.runners.interactive.testing.mock_env import isolated_env
from apache_beam.runners.runner import PipelineState
from apache_beam.testing.test_stream import TestStream
@dataclasses.dataclass
class MockClusterMetadata:
master_url = 'mock_url'
class Record(NamedTuple):
order_id: int
product_id: int
quantity: int
# The module name is also a variable in module.
_module_name = 'apache_beam.runners.interactive.interactive_beam_test'
def _get_watched_pcollections_with_variable_names():
watched_pcollections = {}
for watching in ie.current_env().watching():
for key, val in watching:
if hasattr(val, '__class__') and isinstance(val, beam.pvalue.PCollection):
watched_pcollections[val] = key
return watched_pcollections
@unittest.skipIf(
not ie.current_env().is_interactive_ready,
'[interactive] dependency is not installed.')
@isolated_env
class InteractiveBeamTest(unittest.TestCase):
def setUp(self):
self._var_in_class_instance = 'a var in class instance, not directly used'
def tearDown(self):
ib.options.capture_control.set_limiters_for_test([])
def test_watch_main_by_default(self):
test_env = ie.InteractiveEnvironment()
# Current Interactive Beam env fetched and the test env are 2 instances.
self.assertNotEqual(id(ie.current_env()), id(test_env))
self.assertEqual(ie.current_env().watching(), test_env.watching())
def test_watch_a_module_by_name(self):
test_env = ie.InteractiveEnvironment()
ib.watch(_module_name)
test_env.watch(_module_name)
self.assertEqual(ie.current_env().watching(), test_env.watching())
def test_watch_a_module_by_module_object(self):
test_env = ie.InteractiveEnvironment()
module = importlib.import_module(_module_name)
ib.watch(module)
test_env.watch(module)
self.assertEqual(ie.current_env().watching(), test_env.watching())
def test_watch_locals(self):
# test_env serves as local var too.
test_env = ie.InteractiveEnvironment()
ib.watch(locals())
test_env.watch(locals())
self.assertEqual(ie.current_env().watching(), test_env.watching())
def test_watch_class_instance(self):
test_env = ie.InteractiveEnvironment()
ib.watch(self)
test_env.watch(self)
self.assertEqual(ie.current_env().watching(), test_env.watching())
@unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
def test_show_always_watch_given_pcolls(self):
p = beam.Pipeline(ir.InteractiveRunner())
# pylint: disable=bad-option-value
pcoll = p | 'Create' >> beam.Create(range(10))
# The pcoll is not watched since watch(locals()) is not explicitly called.
self.assertFalse(pcoll in _get_watched_pcollections_with_variable_names())
# The call of show watches pcoll.
ib.watch({'p': p})
ie.current_env().track_user_pipelines()
ib.show(pcoll)
self.assertTrue(pcoll in _get_watched_pcollections_with_variable_names())
@unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
def test_show_mark_pcolls_computed_when_done(self):
p = beam.Pipeline(ir.InteractiveRunner())
# pylint: disable=bad-option-value
pcoll = p | 'Create' >> beam.Create(range(10))
self.assertFalse(pcoll in ie.current_env().computed_pcollections)
# The call of show marks pcoll computed.
ib.watch(locals())
ie.current_env().track_user_pipelines()
ib.show(pcoll)
self.assertTrue(pcoll in ie.current_env().computed_pcollections)
@patch((
'apache_beam.runners.interactive.interactive_beam.'
'visualize_computed_pcoll'))
def test_show_handles_dict_of_pcolls(self, mocked_visualize):
p = beam.Pipeline(ir.InteractiveRunner())
# pylint: disable=bad-option-value
pcoll = p | 'Create' >> beam.Create(range(10))
ib.watch(locals())
ie.current_env().track_user_pipelines()
ie.current_env().mark_pcollection_computed([pcoll])
ie.current_env()._is_in_ipython = True
ie.current_env()._is_in_notebook = True
ib.show({'pcoll': pcoll})
mocked_visualize.assert_called_once()
@patch((
'apache_beam.runners.interactive.interactive_beam.'
'visualize_computed_pcoll'))
def test_show_handles_iterable_of_pcolls(self, mocked_visualize):
p = beam.Pipeline(ir.InteractiveRunner())
# pylint: disable=bad-option-value
pcoll = p | 'Create' >> beam.Create(range(10))
ib.watch(locals())
ie.current_env().track_user_pipelines()
ie.current_env().mark_pcollection_computed([pcoll])
ie.current_env()._is_in_ipython = True
ie.current_env()._is_in_notebook = True
ib.show([pcoll])
mocked_visualize.assert_called_once()
@patch('apache_beam.runners.interactive.interactive_beam.visualize')
def test_show_handles_deferred_dataframes(self, mocked_visualize):
p = beam.Pipeline(ir.InteractiveRunner())
deferred = frames.convert.to_dataframe(p | beam.Create([Record(0, 0, 0)]))
ib.watch(locals())
ie.current_env().track_user_pipelines()
ie.current_env()._is_in_ipython = True
ie.current_env()._is_in_notebook = True
ib.show(deferred)
mocked_visualize.assert_called_once()
@patch((
'apache_beam.runners.interactive.interactive_beam.'
'visualize_computed_pcoll'))
def test_show_noop_when_pcoll_container_is_invalid(self, mocked_visualize):
class SomeRandomClass:
def __init__(self, pcoll):
self._pcoll = pcoll
p = beam.Pipeline(ir.InteractiveRunner())
# pylint: disable=bad-option-value
pcoll = p | 'Create' >> beam.Create(range(10))
ie.current_env().mark_pcollection_computed([pcoll])
ie.current_env()._is_in_ipython = True
ie.current_env()._is_in_notebook = True
self.assertRaises(ValueError, ib.show, SomeRandomClass(pcoll))
mocked_visualize.assert_not_called()
def test_recordings_describe(self):
"""Tests that getting the description works."""
# Create the pipelines to test.
p1 = beam.Pipeline(ir.InteractiveRunner())
p2 = beam.Pipeline(ir.InteractiveRunner())
ib.watch(locals())
# Get the descriptions. This test is simple as there isn't much logic in the
# method.
self.assertEqual(ib.recordings.describe(p1)['size'], 0)
self.assertEqual(ib.recordings.describe(p2)['size'], 0)
all_descriptions = ib.recordings.describe()
self.assertEqual(all_descriptions[p1]['size'], 0)
self.assertEqual(all_descriptions[p2]['size'], 0)
# Ensure that the variable name for the pipeline is set correctly.
self.assertEqual(all_descriptions[p1]['pipeline_var'], 'p1')
self.assertEqual(all_descriptions[p2]['pipeline_var'], 'p2')
def test_recordings_clear(self):
"""Tests that clearing the pipeline is correctly forwarded."""
# Create a basic pipeline to store something in the cache.
p = beam.Pipeline(ir.InteractiveRunner())
elem = p | beam.Create([1])
ib.watch(locals())
ie.current_env().track_user_pipelines()
# This records the pipeline so that the cache size is > 0.
ib.collect(elem)
self.assertGreater(ib.recordings.describe(p)['size'], 0)
# After clearing, the cache should be empty.
ib.recordings.clear(p)
self.assertEqual(ib.recordings.describe(p)['size'], 0)
def test_recordings_record(self):
"""Tests that recording pipeline succeeds."""
# Add the TestStream so that it can be cached.
ib.options.recordable_sources.add(TestStream)
# Create a pipeline with an arbitrary amonunt of elements.
p = beam.Pipeline(
ir.InteractiveRunner(), options=PipelineOptions(streaming=True))
# pylint: disable=unused-variable
_ = (p
| TestStream()
.advance_watermark_to(0)
.advance_processing_time(1)
.add_elements(list(range(10)))
.advance_processing_time(1)) # yapf: disable
ib.watch(locals())
ie.current_env().track_user_pipelines()
# Assert that the pipeline starts in a good state.
self.assertEqual(ib.recordings.describe(p)['state'], PipelineState.STOPPED)
self.assertEqual(ib.recordings.describe(p)['size'], 0)
# Create a lmiter that stops the background caching job when something is
# written to cache. This is used to make ensure that the pipeline is
# functioning properly and that there are no data races with the test.
class SizeLimiter(Limiter):
def __init__(self, pipeline):
self.pipeline = pipeline
self.should_trigger = False
def is_triggered(self):
return (
ib.recordings.describe(self.pipeline)['size'] > 0 and
self.should_trigger)
limiter = SizeLimiter(p)
ib.options.capture_control.set_limiters_for_test([limiter])
# Assert that a recording can be started only once.
self.assertTrue(ib.recordings.record(p))
self.assertFalse(ib.recordings.record(p))
self.assertEqual(ib.recordings.describe(p)['state'], PipelineState.RUNNING)
# Wait for the pipeline to start and write something to cache.
limiter.should_trigger = True
for _ in range(60):
if limiter.is_triggered():
break
time.sleep(1)
self.assertTrue(
limiter.is_triggered(),
'Test timed out waiting for limiter to be triggered. This indicates '
'that the BackgroundCachingJob did not cache anything.')
# Assert that a recording can be stopped and can't be started again until
# after the cache is cleared.
ib.recordings.stop(p)
self.assertEqual(ib.recordings.describe(p)['state'], PipelineState.STOPPED)
self.assertFalse(ib.recordings.record(p))
ib.recordings.clear(p)
self.assertTrue(ib.recordings.record(p))
ib.recordings.stop(p)
def test_collect_raw_records_true(self):
p = beam.Pipeline(ir.InteractiveRunner())
data = list(range(5))
pcoll = p | 'Create' >> beam.Create(data)
ib.watch(locals())
ie.current_env().track_user_pipelines()
result = ib.collect(pcoll, raw_records=True)
self.assertIsInstance(result, list)
self.assertEqual(result, data)
result_n = ib.collect(pcoll, n=3, raw_records=True)
self.assertIsInstance(result_n, list)
self.assertEqual(result_n, data[:3])
def test_collect_raw_records_false(self):
p = beam.Pipeline(ir.InteractiveRunner())
data = list(range(5))
pcoll = p | 'Create' >> beam.Create(data)
ib.watch(locals())
ie.current_env().track_user_pipelines()
result = ib.collect(pcoll)
self.assertNotIsInstance(result, list)
self.assertTrue(
hasattr(result, 'columns'), "Result should have 'columns' attribute")
self.assertTrue(
hasattr(result, 'values'), "Result should have 'values' attribute")
result_n = ib.collect(pcoll, n=3)
self.assertNotIsInstance(result_n, list)
self.assertTrue(
hasattr(result_n, 'columns'),
"Result (n=3) should have 'columns' attribute")
self.assertTrue(
hasattr(result_n, 'values'),
"Result (n=3) should have 'values' attribute")
def test_collect_raw_records_true_multiple_pcolls(self):
p = beam.Pipeline(ir.InteractiveRunner())
data1 = list(range(3))
data2 = [x * x for x in range(3)]
pcoll1 = p | 'Create1' >> beam.Create(data1)
pcoll2 = p | 'Create2' >> beam.Create(data2)
ib.watch(locals())
ie.current_env().track_user_pipelines()
result = ib.collect(pcoll1, pcoll2, raw_records=True)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)
self.assertIsInstance(result[0], list)
self.assertEqual(result[0], data1)
self.assertIsInstance(result[1], list)
self.assertEqual(result[1], data2)
def test_collect_raw_records_false_multiple_pcolls(self):
p = beam.Pipeline(ir.InteractiveRunner())
data1 = list(range(3))
data2 = [x * x for x in range(3)]
pcoll1 = p | 'Create1' >> beam.Create(data1)
pcoll2 = p | 'Create2' >> beam.Create(data2)
ib.watch(locals())
ie.current_env().track_user_pipelines()
result = ib.collect(pcoll1, pcoll2)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)
self.assertNotIsInstance(result[0], list)
self.assertTrue(hasattr(result[0], 'columns'))
self.assertNotIsInstance(result[1], list)
self.assertTrue(hasattr(result[1], 'columns'))
def test_collect_raw_records_true_force_tuple(self):
p = beam.Pipeline(ir.InteractiveRunner())
data = list(range(5))
pcoll = p | 'Create' >> beam.Create(data)
ib.watch(locals())
ie.current_env().track_user_pipelines()
result = ib.collect(pcoll, raw_records=True, force_tuple=True)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 1)
self.assertIsInstance(result[0], list)
self.assertEqual(result[0], data)
@unittest.skipIf(
not ie.current_env().is_interactive_ready,
'[interactive] dependency is not installed.')
@isolated_env
class InteractiveBeamClustersTest(unittest.TestCase):
def setUp(self):
self.current_env.options.cache_root = 'gs://fake'
self.clusters = self.current_env.clusters
def tearDown(self):
self.current_env.options.cache_root = None
def test_cluster_metadata_pass_through_metadata(self):
cid = ClusterMetadata(project_id='test-project')
meta = self.clusters.cluster_metadata(cid)
self.assertIs(meta, cid)
def test_cluster_metadata_identifies_pipeline(self):
cid = beam.Pipeline()
known_meta = ClusterMetadata(project_id='test-project')
dcm = DataprocClusterManager(known_meta)
self.clusters.pipelines[cid] = dcm
meta = self.clusters.cluster_metadata(cid)
self.assertIs(meta, known_meta)
def test_cluster_metadata_identifies_master_url(self):
cid = 'test-url'
known_meta = ClusterMetadata(project_id='test-project')
_ = DataprocClusterManager(known_meta)
self.clusters.master_urls[cid] = known_meta
meta = self.clusters.cluster_metadata(cid)
self.assertIs(meta, known_meta)
def test_cluster_metadata_default_value(self):
cid_none = None
cid_unknown_p = beam.Pipeline()
cid_unknown_master_url = 'test-url'
default_meta = ClusterMetadata(project_id='test-project')
self.clusters.set_default_cluster(default_meta)
self.assertIs(default_meta, self.clusters.cluster_metadata(cid_none))
self.assertIs(default_meta, self.clusters.cluster_metadata(cid_unknown_p))
self.assertIs(
default_meta, self.clusters.cluster_metadata(cid_unknown_master_url))
def test_create_a_new_cluster(self):
meta = ClusterMetadata(project_id='test-project')
_ = self.clusters.create(meta)
# Derived fields are populated.
self.assertTrue(meta.master_url.startswith('test-url'))
self.assertEqual(meta.dashboard, 'test-dashboard')
# The cluster is known.
self.assertIn(meta, self.clusters.dataproc_cluster_managers)
self.assertIn(meta.master_url, self.clusters.master_urls)
# The default cluster is updated to the created cluster.
self.assertIs(meta, self.clusters.default_cluster_metadata)
def test_create_but_reuse_a_known_cluster(self):
known_meta = ClusterMetadata(
project_id='test-project', region='test-region')
known_dcm = DataprocClusterManager(known_meta)
known_meta.master_url = 'test-url'
self.clusters.set_default_cluster(known_meta)
self.clusters.dataproc_cluster_managers[known_meta] = known_dcm
self.clusters.master_urls[known_meta.master_url] = known_meta
# Use an equivalent meta as the identifier to create a cluster.
cid_meta = ClusterMetadata(
project_id=known_meta.project_id,
region=known_meta.region,
cluster_name=known_meta.cluster_name)
dcm = self.clusters.create(cid_meta)
# The known cluster manager is returned.
self.assertIs(dcm, known_dcm)
# Then use an equivalent master_url as the identifier.
cid_master_url = known_meta.master_url
dcm = self.clusters.create(cid_master_url)
self.assertIs(dcm, known_dcm)
def test_cleanup_by_a_pipeline(self):
meta = ClusterMetadata(project_id='test-project')
dcm = self.clusters.create(meta)
# Set up the association between a pipeline and a cluster.
# In real code, it's set by the runner the 1st time a pipeline is executed.
options = PipelineOptions()
options.view_as(FlinkRunnerOptions).flink_master = meta.master_url
p = beam.Pipeline(options=options)
self.clusters.pipelines[p] = dcm
dcm.pipelines.add(p)
self.clusters.cleanup(p)
# Delete the cluster.
self.m_delete_cluster.assert_called_once()
# Pipeline association is cleaned up.
self.assertNotIn(p, self.clusters.pipelines)
self.assertNotIn(p, dcm.pipelines)
# The internal option in the pipeline is overwritten.
self.assertEqual(
p.options.view_as(FlinkRunnerOptions).flink_master, '[auto]')
# The original option is unchanged.
self.assertEqual(
options.view_as(FlinkRunnerOptions).flink_master, meta.master_url)
# The cluster is unknown now.
self.assertNotIn(meta, self.clusters.dataproc_cluster_managers)
self.assertNotIn(meta.master_url, self.clusters.master_urls)
# The cleaned up cluster is also the default cluster. Clean the default.
self.assertIsNone(self.clusters.default_cluster_metadata)
def test_not_cleanup_if_multiple_pipelines_share_a_manager(self):
meta = ClusterMetadata(project_id='test-project')
dcm = self.clusters.create(meta)
options = PipelineOptions()
options.view_as(FlinkRunnerOptions).flink_master = meta.master_url
options2 = PipelineOptions()
options2.view_as(FlinkRunnerOptions).flink_master = meta.master_url
p = beam.Pipeline(options=options)
p2 = beam.Pipeline(options=options2)
self.clusters.pipelines[p] = dcm
self.clusters.pipelines[p2] = dcm
dcm.pipelines.add(p)
dcm.pipelines.add(p2)
self.clusters.cleanup(p)
# No cluster deleted.
self.m_delete_cluster.assert_not_called()
# Pipeline association of p is cleaned up.
self.assertNotIn(p, self.clusters.pipelines)
self.assertNotIn(p, dcm.pipelines)
# The internal option in the pipeline is overwritten.
self.assertEqual(
p.options.view_as(FlinkRunnerOptions).flink_master, '[auto]')
# The original option is unchanged.
self.assertEqual(
options.view_as(FlinkRunnerOptions).flink_master, meta.master_url)
# Pipeline association of p2 still presents.
self.assertIn(p2, self.clusters.pipelines)
self.assertIn(p2, dcm.pipelines)
self.assertEqual(
p2.options.view_as(FlinkRunnerOptions).flink_master, meta.master_url)
self.assertEqual(
options2.view_as(FlinkRunnerOptions).flink_master, meta.master_url)
# The cluster is still known.
self.assertIn(meta, self.clusters.dataproc_cluster_managers)
self.assertIn(meta.master_url, self.clusters.master_urls)
# The default cluster still presents.
self.assertIs(meta, self.clusters.default_cluster_metadata)
def test_cleanup_by_a_master_url(self):
meta = ClusterMetadata(project_id='test-project')
_ = self.clusters.create(meta)
self.clusters.cleanup(meta.master_url)
self.m_delete_cluster.assert_called_once()
self.assertNotIn(meta, self.clusters.dataproc_cluster_managers)
self.assertNotIn(meta.master_url, self.clusters.master_urls)
self.assertIsNone(self.clusters.default_cluster_metadata)
def test_cleanup_by_meta(self):
known_meta = ClusterMetadata(
project_id='test-project', region='test-region')
_ = self.clusters.create(known_meta)
meta = ClusterMetadata(
project_id=known_meta.project_id,
region=known_meta.region,
cluster_name=known_meta.cluster_name)
self.clusters.cleanup(meta)
self.m_delete_cluster.assert_called_once()
self.assertNotIn(known_meta, self.clusters.dataproc_cluster_managers)
self.assertNotIn(known_meta.master_url, self.clusters.master_urls)
self.assertIsNone(self.clusters.default_cluster_metadata)
def test_force_cleanup_everything(self):
meta = ClusterMetadata(project_id='test-project')
meta2 = ClusterMetadata(project_id='test-project-2')
_ = self.clusters.create(meta)
_ = self.clusters.create(meta2)
self.clusters.cleanup(force=True)
self.assertEqual(self.m_delete_cluster.call_count, 2)
self.assertNotIn(meta, self.clusters.dataproc_cluster_managers)
self.assertNotIn(meta2, self.clusters.dataproc_cluster_managers)
self.assertIsNone(self.clusters.default_cluster_metadata)
def test_cleanup_noop_for_no_cluster_identifier(self):
meta = ClusterMetadata(project_id='test-project')
_ = self.clusters.create(meta)
self.clusters.cleanup()
self.m_delete_cluster.assert_not_called()
def test_cleanup_noop_unknown_cluster(self):
meta = ClusterMetadata(project_id='test-project')
dcm = self.clusters.create(meta)
p = beam.Pipeline()
self.clusters.pipelines[p] = dcm
dcm.pipelines.add(p)
cid_pipeline = beam.Pipeline()
self.clusters.cleanup(cid_pipeline)
self.m_delete_cluster.assert_not_called()
cid_master_url = 'some-random-url'
self.clusters.cleanup(cid_master_url)
self.m_delete_cluster.assert_not_called()
cid_meta = ClusterMetadata(project_id='random-project')
self.clusters.cleanup(cid_meta)
self.m_delete_cluster.assert_not_called()
self.assertIn(meta, self.clusters.dataproc_cluster_managers)
self.assertIn(meta.master_url, self.clusters.master_urls)
self.assertIs(meta, self.clusters.default_cluster_metadata)
self.assertIn(p, self.clusters.pipelines)
self.assertIn(p, dcm.pipelines)
def test_describe_everything(self):
meta = ClusterMetadata(project_id='test-project')
meta2 = ClusterMetadata(
project_id='test-project', region='some-other-region')
_ = self.clusters.create(meta)
_ = self.clusters.create(meta2)
meta_list = self.clusters.describe()
self.assertEqual([meta, meta2], meta_list)
def test_describe_by_cluster_identifier(self):
known_meta = ClusterMetadata(project_id='test-project')
known_meta2 = ClusterMetadata(
project_id='test-project', region='some-other-region')
dcm = self.clusters.create(known_meta)
dcm2 = self.clusters.create(known_meta2)
p = beam.Pipeline()
p2 = beam.Pipeline()
self.clusters.pipelines[p] = dcm
dcm.pipelines.add(p)
self.clusters.pipelines[p2] = dcm2
dcm.pipelines.add(p2)
cid_pipeline = p
meta = self.clusters.describe(cid_pipeline)
self.assertIs(meta, known_meta)
cid_master_url = known_meta.master_url
meta = self.clusters.describe(cid_master_url)
self.assertIs(meta, known_meta)
cid_meta = ClusterMetadata(
project_id=known_meta.project_id,
region=known_meta.region,
cluster_name=known_meta.cluster_name)
meta = self.clusters.describe(cid_meta)
self.assertIs(meta, known_meta)
def test_describe_everything_when_cluster_identifer_unknown(self):
known_meta = ClusterMetadata(project_id='test-project')
known_meta2 = ClusterMetadata(
project_id='test-project', region='some-other-region')
dcm = self.clusters.create(known_meta)
dcm2 = self.clusters.create(known_meta2)
p = beam.Pipeline()
p2 = beam.Pipeline()
self.clusters.pipelines[p] = dcm
dcm.pipelines.add(p)
self.clusters.pipelines[p2] = dcm2
dcm.pipelines.add(p2)
cid_pipeline = beam.Pipeline()
meta_list = self.clusters.describe(cid_pipeline)
self.assertEqual([known_meta, known_meta2], meta_list)
cid_master_url = 'some-random-url'
meta_list = self.clusters.describe(cid_master_url)
self.assertEqual([known_meta, known_meta2], meta_list)
cid_meta = ClusterMetadata(project_id='some-random-project')
meta_list = self.clusters.describe(cid_meta)
self.assertEqual([known_meta, known_meta2], meta_list)
def test_default_value_for_invalid_worker_number(self):
meta = ClusterMetadata(project_id='test-project', num_workers=1)
self.clusters.create(meta)
self.assertEqual(meta.num_workers, 2)
@unittest.skipIf(
not ie.current_env().is_interactive_ready,
'[interactive] dependency is not installed.')
@isolated_env
class InteractiveBeamComputeTest(unittest.TestCase):
def setUp(self):
self.env = ie.current_env()
self.env._is_in_ipython = False # Default to non-IPython
def test_compute_blocking(self):
p = beam.Pipeline(ir.InteractiveRunner())
data = list(range(10))
pcoll = p | 'Create' >> beam.Create(data)
ib.watch(locals())
self.env.track_user_pipelines()
result = ib.compute(pcoll, blocking=True)
self.assertIsNone(result) # Blocking returns None
self.assertTrue(pcoll in self.env.computed_pcollections)
collected = ib.collect(pcoll, raw_records=True)
self.assertEqual(collected, data)
def test_compute_non_blocking(self):
p = beam.Pipeline(ir.InteractiveRunner())
data = list(range(5))
pcoll = p | 'Create' >> beam.Create(data)
ib.watch(locals())
self.env.track_user_pipelines()
async_result = ib.compute(pcoll, blocking=False)
self.assertIsInstance(async_result, AsyncComputationResult)
pipeline_result = async_result.result(timeout=60)
self.assertTrue(async_result.done())
self.assertIsNone(async_result.exception())
self.assertEqual(pipeline_result.state, PipelineState.DONE)
self.assertTrue(pcoll in self.env.computed_pcollections)
collected = ib.collect(pcoll, raw_records=True)
self.assertEqual(collected, data)
def test_compute_with_list_input(self):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
ib.watch(locals())
self.env.track_user_pipelines()
ib.compute([pcoll1, pcoll2], blocking=True)
self.assertTrue(pcoll1 in self.env.computed_pcollections)
self.assertTrue(pcoll2 in self.env.computed_pcollections)
def test_compute_with_dict_input(self):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
ib.watch(locals())
self.env.track_user_pipelines()
ib.compute({'a': pcoll1, 'b': pcoll2}, blocking=True)
self.assertTrue(pcoll1 in self.env.computed_pcollections)
self.assertTrue(pcoll2 in self.env.computed_pcollections)
def test_compute_empty_input(self):
result = ib.compute([], blocking=True)
self.assertIsNone(result)
result_async = ib.compute([], blocking=False)
self.assertIsNone(result_async)
def test_compute_force_recompute(self):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll = p | 'Create' >> beam.Create([1, 2, 3])
ib.watch(locals())
self.env.track_user_pipelines()
ib.compute(pcoll, blocking=True)
self.assertTrue(pcoll in self.env.computed_pcollections)
# Mock evict_computed_pcollections to check if it's called
with patch.object(self.env, 'evict_computed_pcollections') as mock_evict:
ib.compute(pcoll, blocking=True, force_compute=True)
mock_evict.assert_called_once_with(p)
self.assertTrue(pcoll in self.env.computed_pcollections)
def test_compute_non_blocking_exception(self):
p = beam.Pipeline(ir.InteractiveRunner())
def raise_error(elem):
raise ValueError('Test Error')
pcoll = p | 'Create' >> beam.Create([1]) | 'Error' >> beam.Map(raise_error)
ib.watch(locals())
self.env.track_user_pipelines()
async_result = ib.compute(pcoll, blocking=False)
self.assertIsInstance(async_result, AsyncComputationResult)
with self.assertRaises(ValueError):
async_result.result(timeout=60)
self.assertTrue(async_result.done())
self.assertIsInstance(async_result.exception(), ValueError)
self.assertFalse(pcoll in self.env.computed_pcollections)
@patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
@patch('apache_beam.runners.interactive.recording_manager.display')
@patch('apache_beam.runners.interactive.recording_manager.clear_output')
@patch('apache_beam.runners.interactive.recording_manager.HTML')
@patch('ipywidgets.Button')
@patch('ipywidgets.FloatProgress')
@patch('ipywidgets.Output')
@patch('ipywidgets.HBox')
@patch('ipywidgets.VBox')
def test_compute_non_blocking_ipython_widgets(
self,
mock_vbox,
mock_hbox,
mock_output,
mock_progress,
mock_button,
mock_html,
mock_clear_output,
mock_display,
):
self.env._is_in_ipython = True
p = beam.Pipeline(ir.InteractiveRunner())
pcoll = p | 'Create' >> beam.Create(range(3))
ib.watch(locals())
self.env.track_user_pipelines()
mock_controls = mock_vbox.return_value
mock_html_instance = mock_html.return_value
async_result = ib.compute(pcoll, blocking=False)
self.assertIsNotNone(async_result)
mock_button.assert_called_once_with(description='Cancel')
mock_progress.assert_called_once()
mock_output.assert_called_once()
mock_hbox.assert_called_once()
mock_vbox.assert_called_once()
mock_html.assert_called_once_with('<p>Initializing...</p>')
self.assertEqual(mock_display.call_count, 2)
mock_display.assert_has_calls([
call(mock_controls, display_id=async_result._display_id),
call(mock_html_instance)
])
mock_clear_output.assert_called_once()
async_result.result(timeout=60) # Let it finish
def test_compute_dependency_wait_true(self):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2)
ib.watch(locals())
self.env.track_user_pipelines()
rm = self.env.get_recording_manager(p)
# Start pcoll1 computation
async_res1 = ib.compute(pcoll1, blocking=False)
self.assertTrue(self.env.is_pcollection_computing(pcoll1))
# Spy on _wait_for_dependencies
with patch.object(rm,
'_wait_for_dependencies',
wraps=rm._wait_for_dependencies) as spy_wait:
async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=True)
# Check that wait_for_dependencies was called for pcoll2
spy_wait.assert_called_with({pcoll2}, async_res2)
# Let pcoll1 finish
async_res1.result(timeout=60)
self.assertTrue(pcoll1 in self.env.computed_pcollections)
self.assertFalse(self.env.is_pcollection_computing(pcoll1))
# pcoll2 should now run and complete
async_res2.result(timeout=60)
self.assertTrue(pcoll2 in self.env.computed_pcollections)
@patch.object(ie.InteractiveEnvironment, 'is_pcollection_computing')
def test_compute_dependency_wait_false(self, mock_is_computing):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2)
ib.watch(locals())
self.env.track_user_pipelines()
rm = self.env.get_recording_manager(p)
# Pretend pcoll1 is computing
mock_is_computing.side_effect = lambda pcoll: pcoll is pcoll1
with patch.object(rm,
'_execute_pipeline_fragment',
wraps=rm._execute_pipeline_fragment) as spy_execute:
async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=False)
async_res2.result(timeout=60)
# Assert that execute was called for pcoll2 without waiting
spy_execute.assert_called_with({pcoll2}, async_res2, ANY, ANY)
self.assertTrue(pcoll2 in self.env.computed_pcollections)
def test_async_computation_result_cancel(self):
p = beam.Pipeline(ir.InteractiveRunner())
# A stream that never finishes to test cancellation
pcoll = p | beam.Create([1]) | beam.Map(lambda x: time.sleep(100))
ib.watch(locals())
self.env.track_user_pipelines()
async_result = ib.compute(pcoll, blocking=False)
self.assertIsInstance(async_result, AsyncComputationResult)
# Give it a moment to start
time.sleep(0.1)
# Mock the pipeline result's cancel method
mock_pipeline_result = MagicMock()
mock_pipeline_result.state = PipelineState.RUNNING
async_result.set_pipeline_result(mock_pipeline_result)
self.assertTrue(async_result.cancel())
mock_pipeline_result.cancel.assert_called_once()
# The future should be cancelled eventually by the runner
# This part is hard to test without deeper runner integration
with self.assertRaises(TimeoutError):
async_result.result(timeout=1) # It should not complete successfully
@patch(
'apache_beam.runners.interactive.recording_manager.RecordingManager.'
'_execute_pipeline_fragment')
def test_compute_multiple_async(self, mock_execute_fragment):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
ib.watch(locals())
self.env.track_user_pipelines()
mock_pipeline_result = MagicMock()
mock_pipeline_result.state = PipelineState.DONE
mock_execute_fragment.return_value = mock_pipeline_result
res1 = ib.compute(pcoll1, blocking=False)
res2 = ib.compute(pcoll2, blocking=False)
res3 = ib.compute(pcoll3, blocking=False) # Depends on pcoll1
self.assertIsNotNone(res1)
self.assertIsNotNone(res2)
self.assertIsNotNone(res3)
res1.result(timeout=60)
res2.result(timeout=60)
res3.result(timeout=60)
time.sleep(0.1)
self.assertTrue(
pcoll1 in self.env.computed_pcollections, "pcoll1 not marked computed")
self.assertTrue(
pcoll2 in self.env.computed_pcollections, "pcoll2 not marked computed")
self.assertTrue(
pcoll3 in self.env.computed_pcollections, "pcoll3 not marked computed")
self.assertEqual(mock_execute_fragment.call_count, 3)
@patch(
'apache_beam.runners.interactive.interactive_beam.'
'deferred_df_to_pcollection')
def test_compute_input_flattening(self, mock_deferred_to_pcoll):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p | 'C1' >> beam.Create([1])
pcoll2 = p | 'C2' >> beam.Create([2])
pcoll3 = p | 'C3' >> beam.Create([3])
pcoll4 = p | 'C4' >> beam.Create([4])
class MockDeferred(DeferredBase):
def __init__(self, pcoll):
mock_expr = MagicMock()
super().__init__(mock_expr)
self._pcoll = pcoll
def _get_underlying_pcollection(self):
return self._pcoll
deferred_pcoll = MockDeferred(pcoll4)
mock_deferred_to_pcoll.return_value = (pcoll4, p)
ib.watch(locals())
self.env.track_user_pipelines()
with patch.object(self.env, 'get_recording_manager') as mock_get_rm:
mock_rm = MagicMock()
mock_get_rm.return_value = mock_rm
ib.compute(pcoll1, [pcoll2], {'a': pcoll3}, deferred_pcoll)
expected_pcolls = {pcoll1, pcoll2, pcoll3, pcoll4}
mock_rm.compute_async.assert_called_once_with(
expected_pcolls,
wait_for_inputs=True,
blocking=False,
runner=None,
options=None,
force_compute=False)
def test_compute_invalid_input_type(self):
with self.assertRaisesRegex(ValueError,
"not a dict, an iterable or a PCollection"):
ib.compute(123)
def test_compute_mixed_pipelines(self):
p1 = beam.Pipeline(ir.InteractiveRunner())
pcoll1 = p1 | 'C1' >> beam.Create([1])
p2 = beam.Pipeline(ir.InteractiveRunner())
pcoll2 = p2 | 'C2' >> beam.Create([2])
ib.watch(locals())
self.env.track_user_pipelines()
with self.assertRaisesRegex(
ValueError, "All PCollections must belong to the same pipeline"):
ib.compute(pcoll1, pcoll2)
@patch(
'apache_beam.runners.interactive.interactive_beam.'
'deferred_df_to_pcollection')
@patch.object(ib, 'watch')
def test_compute_with_deferred_base(self, mock_watch, mock_deferred_to_pcoll):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll = p | 'C1' >> beam.Create([1])
class MockDeferred(DeferredBase):
def __init__(self, pcoll):
# Provide a dummy expression to satisfy DeferredBase.__init__
mock_expr = MagicMock()
super().__init__(mock_expr)
self._pcoll = pcoll
def _get_underlying_pcollection(self):
return self._pcoll
deferred = MockDeferred(pcoll)
mock_deferred_to_pcoll.return_value = (pcoll, p)
with patch.object(self.env, 'get_recording_manager') as mock_get_rm:
mock_rm = MagicMock()
mock_get_rm.return_value = mock_rm
ib.compute(deferred)
mock_deferred_to_pcoll.assert_called_once_with(deferred)
self.assertEqual(mock_watch.call_count, 2)
mock_watch.assert_has_calls([
call({f'anonymous_pcollection_{id(pcoll)}': pcoll}),
call({f'anonymous_pipeline_{id(p)}': p})
],
any_order=False)
mock_rm.compute_async.assert_called_once_with({pcoll},
wait_for_inputs=True,
blocking=False,
runner=None,
options=None,
force_compute=False)
def test_compute_new_pipeline(self):
p = beam.Pipeline(ir.InteractiveRunner())
pcoll = p | 'Create' >> beam.Create([1])
# NOT calling ib.watch() or track_user_pipelines()
with patch.object(self.env, 'get_recording_manager') as mock_get_rm, \
patch.object(ib, 'watch') as mock_watch:
mock_rm = MagicMock()
mock_get_rm.return_value = mock_rm
ib.compute(pcoll)
mock_watch.assert_called_with({f'anonymous_pipeline_{id(p)}': p})
mock_get_rm.assert_called_once_with(p, create_if_absent=True)
mock_rm.compute_async.assert_called_once()
if __name__ == '__main__':
unittest.main()