# blob: 28ba7c8995c26dfb68f7f680154e302e49d27c17 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Test for Shared class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import threading
import time
import unittest
from apache_beam.utils import shared
class Count(object):
    """Lock-protected pair of counters for references handed out.

    One counter records every reference ever added; the other records how
    many of those have not yet been released.
    """

    def __init__(self):
        self._mutex = threading.Lock()
        self._created = 0
        self._live = 0

    def add_ref(self):
        """Record a new reference: bumps both the total and active counts."""
        with self._mutex:
            self._live += 1
            self._created += 1

    def release_ref(self):
        """Record that a reference went away: lowers only the active count."""
        with self._mutex:
            self._live -= 1

    def get_active(self):
        """Return the number of references added but not yet released."""
        with self._mutex:
            live = self._live
        return live

    def get_total(self):
        """Return the number of references ever added."""
        with self._mutex:
            created = self._created
        return created
class Marker(object):
    """Object whose lifetime is reported to a counter.

    Construction calls ``count.add_ref()`` and finalisation calls
    ``count.release_ref()``, so the counter reflects whether this object is
    still alive.
    """

    def __init__(self, count):
        # Store the counter before notifying it, so __del__ can always find it.
        self._count = count
        count.add_ref()

    def __del__(self):
        tracker = self._count
        tracker.release_ref()
class NamedObject(object):
    """Minimal object that remembers the name it was constructed with."""

    def __init__(self, name):
        self._label = name

    def get_name(self):
        """Return the name passed to the constructor."""
        label = self._label
        return label
class Sequence(object):
    """Factory of acquire functions that share one running sequence number."""

    def __init__(self):
        self._counter = 0

    def make_acquire_fn(self):
        """Return a zero-argument callable producing numbered NamedObjects.

        Each call of the returned function increments the sequence number
        held by this Sequence and returns a NamedObject named after the new
        number ('sequence1', 'sequence2', ...). All functions created from
        the same Sequence share that number.
        """
        def acquire_fn():
            next_number = self._counter + 1
            self._counter = next_number
            return NamedObject('sequence%d' % next_number)

        return acquire_fn
class SharedTest(unittest.TestCase):
    """Tests for object caching and release through shared.Shared handles.

    Several tests acquire a dummy object through a *different* Shared handle
    to "get rid of the keepalive": per the expectations asserted below, the
    most recently acquired object is otherwise kept alive even after every
    caller reference to it has been dropped.
    """

    # A released object survives garbage collection while it is the
    # keep-alive, and is reused (not re-created) when acquired again.
    def testKeepalive(self):
        count = Count()
        shared_handle = shared.Shared()
        other_shared_handle = shared.Shared()

        def dummy_acquire_fn():
            return None

        def acquire_fn():
            return Marker(count)

        p1 = shared_handle.acquire(acquire_fn)
        self.assertEqual(1, count.get_total())
        self.assertEqual(1, count.get_active())
        del p1
        gc.collect()
        # Won't be garbage collected, because of the keep-alive
        self.assertEqual(1, count.get_active())

        # Reacquire.
        p2 = shared_handle.acquire(acquire_fn)
        self.assertEqual(1, count.get_total())  # No reinitialisation.
        self.assertEqual(1, count.get_active())

        # Get rid of the keepalive
        other_shared_handle.acquire(dummy_acquire_fn)
        del p2
        gc.collect()
        self.assertEqual(0, count.get_active())

    # Repeated acquisitions share one object, which is destroyed only once
    # the last reference is gone and re-created on the next acquire.
    def testMultiple(self):
        count = Count()
        shared_handle = shared.Shared()
        other_shared_handle = shared.Shared()

        def dummy_acquire_fn():
            return None

        def acquire_fn():
            return Marker(count)

        p = shared_handle.acquire(acquire_fn)
        # Get rid of the keepalive
        other_shared_handle.acquire(dummy_acquire_fn)
        self.assertEqual(1, count.get_total())
        self.assertEqual(1, count.get_active())
        del p
        gc.collect()
        # Shared value should be garbage collected.
        self.assertEqual(0, count.get_active())

        p1 = shared_handle.acquire(acquire_fn)
        # Since shared value was released, expect a reinitialisation.
        self.assertEqual(2, count.get_total())
        self.assertEqual(1, count.get_active())

        # Acquiring multiple times only results in one initialisation.
        p2 = shared_handle.acquire(acquire_fn)
        self.assertEqual(2, count.get_total())
        self.assertEqual(1, count.get_active())
        # Get rid of the keepalive
        other_shared_handle.acquire(dummy_acquire_fn)

        # Check that shared object isn't destroyed if there's still a
        # reference to it.
        del p2
        gc.collect()
        self.assertEqual(1, count.get_active())
        del p1
        gc.collect()
        self.assertEqual(0, count.get_active())

    # Test that only one among many calls to acquire will actually run the
    # initialisation function.
    def testConcurrentCallsDeduped(self):
        count = Count()
        shared_handle = shared.Shared()
        other_shared_handle = shared.Shared()
        refs = []
        ref_lock = threading.Lock()

        def dummy_acquire_fn():
            return None

        def acquire_fn():
            # Slow initialisation; presumably this gives the other threads
            # time to call acquire while the first call is still running.
            time.sleep(1)
            return Marker(count)

        def thread_fn():
            p = shared_handle.acquire(acquire_fn)
            with ref_lock:
                refs.append(p)

        threads = []
        for _ in range(100):
            t = threading.Thread(target=thread_fn)
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

        self.assertEqual(1, count.get_total())
        self.assertEqual(1, count.get_active())

        # Get rid of the keepalive
        other_shared_handle.acquire(dummy_acquire_fn)
        with ref_lock:
            del refs[:]
        gc.collect()
        self.assertEqual(0, count.get_active())

    # Distinct Shared handles cache distinct objects independently.
    def testDifferentObjects(self):
        sequence = Sequence()

        def dummy_acquire_fn():
            return None

        first_handle = shared.Shared()
        second_handle = shared.Shared()
        dummy_handle = shared.Shared()

        f1 = first_handle.acquire(sequence.make_acquire_fn())
        s1 = second_handle.acquire(sequence.make_acquire_fn())
        self.assertEqual('sequence1', f1.get_name())
        self.assertEqual('sequence2', s1.get_name())

        f2 = first_handle.acquire(sequence.make_acquire_fn())
        s2 = second_handle.acquire(sequence.make_acquire_fn())
        # Check that the repeated acquisitions return the earlier objects
        self.assertEqual('sequence1', f2.get_name())
        self.assertEqual('sequence2', s2.get_name())

        # Release all references and force garbage-collection
        del f1
        del f2
        del s1
        del s2
        # Get rid of the keepalive
        dummy_handle.acquire(dummy_acquire_fn)
        gc.collect()

        # Check that acquiring again after they're released gives new objects
        f3 = first_handle.acquire(sequence.make_acquire_fn())
        s3 = second_handle.acquire(sequence.make_acquire_fn())
        self.assertEqual('sequence3', f3.get_name())
        self.assertEqual('sequence4', s3.get_name())

    # A different tag on the same handle forces the cached object to be
    # replaced; without a tag the first object keeps being returned.
    def testTagCacheEviction(self):
        shared1 = shared.Shared()
        shared2 = shared.Shared()

        def acquire_fn_1():
            return NamedObject('obj_1')

        def acquire_fn_2():
            return NamedObject('obj_2')

        # with no tag, shared handle does not know when to evict objects
        p1 = shared1.acquire(acquire_fn_1)
        self.assertEqual('obj_1', p1.get_name())
        p2 = shared1.acquire(acquire_fn_2)
        self.assertEqual('obj_1', p2.get_name())

        # cache eviction can be forced by specifying different tags
        p1 = shared2.acquire(acquire_fn_1, tag='1')
        self.assertEqual('obj_1', p1.get_name())
        p2 = shared2.acquire(acquire_fn_2, tag='2')
        self.assertEqual('obj_2', p2.get_name())
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()