# blob: f7025b8b75bfdaa2c9f531221732fb0a9450c662
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import apache_beam as beam
from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker
class UserPipelineTrackerTest(unittest.TestCase):
  """Unit tests for UserPipelineTracker.

  Each test builds a fresh tracker and fresh pipelines, so the tests do not
  share any state.
  """
  def test_getting_unknown_pid_returns_none(self):
    """Looking up the id of a pipeline that was never added yields None."""
    ut = UserPipelineTracker()

    p = beam.Pipeline()

    self.assertIsNone(ut.get_pipeline(str(id(p))))

  def test_getting_unknown_pipeline_returns_none(self):
    """Asking for the user pipeline of an untracked pipeline yields None."""
    ut = UserPipelineTracker()

    p = beam.Pipeline()

    self.assertIsNone(ut.get_user_pipeline(p))

  def test_no_parent_returns_none(self):
    """A pipeline that was never added has no user pipeline, even when other
    pipelines are tracked.
    """
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()
    orphan = beam.Pipeline()

    ut.add_derived_pipeline(user, derived)

    self.assertIsNone(ut.get_user_pipeline(orphan))

  def test_get_user_pipeline_is_same(self):
    """The user pipeline of a user pipeline is that same object."""
    ut = UserPipelineTracker()

    p = beam.Pipeline()
    ut.add_user_pipeline(p)

    self.assertIs(ut.get_user_pipeline(p), p)

  def test_can_add_derived(self):
    """A derived pipeline resolves to the user pipeline it was added under."""
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_derived_pipeline(user, derived)

    self.assertIs(ut.get_user_pipeline(derived), user)

  def test_can_add_multiple_derived(self):
    """Tests that there can be many user pipelines with many derived
    pipelines.
    """
    ut = UserPipelineTracker()

    # Add the first set of user and derived pipelines.
    user1 = beam.Pipeline()
    derived11 = beam.Pipeline()
    derived12 = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived11)
    ut.add_derived_pipeline(user1, derived12)

    # Add the second set of user and derived pipelines.
    user2 = beam.Pipeline()
    derived21 = beam.Pipeline()
    derived22 = beam.Pipeline()

    ut.add_derived_pipeline(user2, derived21)
    ut.add_derived_pipeline(user2, derived22)

    # Assert that the user pipelines are correct.
    self.assertIs(ut.get_user_pipeline(derived11), user1)
    self.assertIs(ut.get_user_pipeline(derived12), user1)
    self.assertIs(ut.get_user_pipeline(derived21), user2)
    self.assertIs(ut.get_user_pipeline(derived22), user2)

  def test_cannot_have_multiple_parents(self):
    """Adding an already-derived pipeline under a second user pipeline raises
    an AssertionError and leaves the first association in place.
    """
    ut = UserPipelineTracker()

    user1 = beam.Pipeline()
    user2 = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived)

    with self.assertRaises(AssertionError):
      ut.add_derived_pipeline(user2, derived)

    # Checked outside the `with` block: anything after the raising call inside
    # the block would never execute.
    self.assertIs(ut.get_user_pipeline(derived), user1)

  def test_adding_derived_with_derived_gets_user_pipeline(self):
    """Tests that one can correctly add a derived pipeline from a derived
    pipeline and still get the correct user pipeline.
    """
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived1 = beam.Pipeline()
    derived2 = beam.Pipeline()

    # Add the first derived pipeline to the user pipeline.
    ut.add_derived_pipeline(user, derived1)

    # Add the second derived pipeline to the first derived pipeline. This should
    # get the user pipeline of the first and add the second to it.
    ut.add_derived_pipeline(derived1, derived2)

    # Asserts that both derived pipelines are under the same user pipeline.
    self.assertIs(ut.get_user_pipeline(derived1), user)
    self.assertIs(ut.get_user_pipeline(derived2), user)

  def test_can_get_pipeline_from_id(self):
    """Tests the pid -> pipeline memoization."""
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_user_pipeline(user)
    ut.add_derived_pipeline(user, derived)

    # Pipelines are looked up by the string form of their `id()`.
    self.assertIs(ut.get_pipeline(str(id(user))), user)
    self.assertIs(ut.get_pipeline(str(id(derived))), derived)

  def test_clear(self):
    """After clear(), neither the user nor the derived pipeline resolves."""
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_derived_pipeline(user, derived)

    # Sanity check that the association exists before clearing.
    self.assertIs(ut.get_user_pipeline(derived), user)

    ut.clear()

    self.assertIsNone(ut.get_user_pipeline(user))
    self.assertIsNone(ut.get_user_pipeline(derived))

  def test_can_iterate(self):
    """Iterating the tracker yields the user pipelines, not the derived ones."""
    ut = UserPipelineTracker()

    user1 = beam.Pipeline()
    derived11 = beam.Pipeline()
    derived12 = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived11)
    ut.add_derived_pipeline(user1, derived12)

    user2 = beam.Pipeline()
    derived21 = beam.Pipeline()
    derived22 = beam.Pipeline()

    ut.add_derived_pipeline(user2, derived21)
    ut.add_derived_pipeline(user2, derived22)

    # Compared as sets so the assertion does not depend on iteration order.
    user_pipelines = set(ut)
    self.assertSetEqual({user1, user2}, user_pipelines)

  def test_can_evict_user_pipeline(self):
    """Evicting a user pipeline drops it and its derived pipelines while
    leaving other user pipelines untouched.
    """
    ut = UserPipelineTracker()

    user1 = beam.Pipeline()
    derived11 = beam.Pipeline()
    derived12 = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived11)
    ut.add_derived_pipeline(user1, derived12)

    user2 = beam.Pipeline()
    derived21 = beam.Pipeline()
    derived22 = beam.Pipeline()

    ut.add_derived_pipeline(user2, derived21)
    ut.add_derived_pipeline(user2, derived22)

    ut.evict(user1)

    # Everything under user1 is gone.
    self.assertIsNone(ut.get_user_pipeline(user1))
    self.assertIsNone(ut.get_user_pipeline(derived11))
    self.assertIsNone(ut.get_user_pipeline(derived12))

    # user2 and its derived pipelines are unaffected.
    self.assertIs(user2, ut.get_user_pipeline(derived21))
    self.assertIs(user2, ut.get_user_pipeline(derived22))

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()