blob: 3c74903b8d9915f7ccb4387a2e290ac3b7f9488b [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import logging
import multiprocessing
import os
import tempfile
import threading
import unittest
from typing import Any
from apache_beam.utils import multi_process_shared
class CallableCounter(object):
  """Counter whose current value is read by calling the instance itself.

  The tests use it to check that a shared proxy forwards ``__call__`` as well
  as ordinary methods.
  """

  def __init__(self, start=0):
    # Guards read-modify-write updates performed by increment().
    self.lock = threading.Lock()
    # Current total; starts at `start`.
    self.running = start

  def __call__(self):
    """Return the current total (read without taking the lock)."""
    return self.running

  def increment(self, value=1):
    """Add ``value`` to the total under the lock and return the new total."""
    with self.lock:
      self.running += value
      new_total = self.running
    return new_total

  def error(self, msg):
    """Unconditionally raise ``RuntimeError(msg)``."""
    raise RuntimeError(msg)
class Counter(object):
  """Counter exposing get/increment/error methods for proxy tests."""

  def __init__(self, start=0):
    # Lock serializing increment(); reads via get() do not take it.
    self.lock = threading.Lock()
    # Current total; starts at `start`.
    self.running = start

  def get(self):
    """Return the current total."""
    return self.running

  def increment(self, value=1):
    """Add ``value`` to the total while holding the lock; return the total."""
    self.lock.acquire()
    try:
      self.running += value
      return self.running
    finally:
      self.lock.release()

  def error(self, msg):
    """Unconditionally raise ``RuntimeError(msg)``."""
    raise RuntimeError(msg)
class CounterWithBadAttr(object):
  """Counter whose ``error`` attribute is rejected on lookup.

  ``error`` is defined, but ``__getattribute__`` raises AttributeError for
  that name, so attribute introspection sees an attribute that cannot be
  fetched. Every other attribute resolves normally.
  """

  def __init__(self, start=0):
    # Lock serializing increment(); reads via get() do not take it.
    self.lock = threading.Lock()
    # Current total; starts at `start`.
    self.running = start

  def get(self):
    """Return the current total."""
    return self.running

  def increment(self, value=1):
    """Add ``value`` to the total under the lock and return the new total."""
    with self.lock:
      self.running += value
      new_total = self.running
    return new_total

  def error(self, msg):
    """Would raise ``RuntimeError(msg)``, but lookup of ``error`` is blocked."""
    raise RuntimeError(msg)

  def __getattribute__(self, __name: str) -> Any:
    # Note: `__name` is name-mangled inside the class body, consistently for
    # every use below, so it behaves like an ordinary parameter.
    if __name != 'error':
      # Default behaviour for every attribute except 'error'.
      return object.__getattribute__(self, __name)
    raise AttributeError('error is not actually supported on this platform')
class MultiProcessSharedTest(unittest.TestCase):
  """Tests for MultiProcessShared objects created without spawn_process."""

  @classmethod
  def setUpClass(cls):
    # Class-wide proxies: `shared` wraps a Counter (tag 'basic') and
    # `sharedCallable` wraps a CallableCounter (tag 'callable').
    cls.shared = multi_process_shared.MultiProcessShared(
        Counter, tag='basic', always_proxy=True).acquire()
    cls.sharedCallable = multi_process_shared.MultiProcessShared(
        CallableCounter, tag='callable', always_proxy=True).acquire()

  @staticmethod
  def _ignore_errors(action):
    """Invoke ``action()`` and discard any Exception it raises.

    The hard-delete calls are wrapped this way because the tests only assert
    on the state afterwards (presumably tearing a server down can also break
    the connection the call was made on).
    """
    try:
      action()
    except Exception:
      pass

  def _check_release_semantics(self, tag, **shared_kwargs):
    """Check acquire/release bookkeeping for two handles sharing ``tag``.

    ``shared_kwargs`` is forwarded verbatim to the MultiProcessShared
    constructor.
    """
    handle1 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, **shared_kwargs)
    handle2 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, **shared_kwargs)
    proxy1 = handle1.acquire()
    proxy2 = handle2.acquire()
    # Both handles see the same underlying counter.
    self.assertEqual(proxy1.increment(), 1)
    self.assertEqual(proxy2.increment(), 2)
    proxy1_again = handle1.acquire()
    self.assertEqual(proxy1_again.increment(), 3)
    handle1.release(proxy1)
    handle2.release(proxy2)
    for stale in (proxy1, proxy2):
      with self.assertRaisesRegex(Exception, 'released'):
        stale.get()
    # An outstanding acquisition keeps the object alive...
    self.assertEqual(proxy1_again.get(), 3)
    handle1.release(proxy1_again)
    # ...and once every reference is released, a new acquire starts fresh.
    fresh = handle1.acquire()
    self.assertEqual(fresh.get(), 0)
    with self.assertRaisesRegex(Exception, 'released'):
      proxy1.get()

  def test_call(self):
    counter = self.shared
    self.assertEqual(counter.get(), 0)
    self.assertEqual(counter.increment(), 1)
    self.assertEqual(counter.increment(10), 11)
    self.assertEqual(counter.increment(value=10), 21)
    self.assertEqual(counter.get(), 21)

  def test_call_illegal_attr(self):
    # CounterWithBadAttr refuses lookups of 'error'; its other methods must
    # still be usable through the proxy.
    handle = multi_process_shared.MultiProcessShared(
        CounterWithBadAttr, tag='test_call_illegal_attr', always_proxy=True)
    counter = handle.acquire()
    self.assertEqual(counter.get(), 0)
    self.assertEqual(counter.increment(), 1)
    self.assertEqual(counter.get(), 1)

  def test_call_callable(self):
    counter = self.sharedCallable
    self.assertEqual(counter(), 0)
    self.assertEqual(counter.increment(), 1)
    self.assertEqual(counter.increment(10), 11)
    self.assertEqual(counter.increment(value=10), 21)
    self.assertEqual(counter(), 21)

  def test_error(self):
    with self.assertRaisesRegex(Exception, 'something bad'):
      self.shared.error('something bad')

  def test_no_method(self):
    with self.assertRaisesRegex(Exception, 'no_such_method'):
      self.shared.no_such_method()

  def test_connect(self):
    proxy_a = multi_process_shared.MultiProcessShared(
        Counter, tag='counter').acquire()
    proxy_b = multi_process_shared.MultiProcessShared(
        Counter, tag='counter').acquire()
    # Two independently constructed handles with one tag share state.
    self.assertEqual(proxy_a.get(), 0)
    self.assertEqual(proxy_a.increment(), 1)
    self.assertEqual(proxy_b.get(), 1)
    self.assertEqual(proxy_b.increment(), 2)
    self.assertEqual(proxy_a.get(), 2)
    self.assertEqual(proxy_a.increment(), 3)

  def test_release(self):
    self._check_release_semantics('test_release')

  def test_unsafe_hard_delete(self):
    tag = 'test_unsafe_hard_delete'
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    counter1 = shared1.acquire()
    counter2 = shared2.acquire()
    self.assertEqual(counter1.increment(), 1)
    self.assertEqual(counter2.increment(), 2)
    # Hard delete through a third, independently constructed handle.
    self._ignore_errors(
        lambda: multi_process_shared.MultiProcessShared(
            Counter, tag=tag).unsafe_hard_delete())
    for dead in (counter1, counter2):
      with self.assertRaises(Exception):
        dead.get()
    shared3 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    counter3 = shared3.acquire()
    self.assertEqual(counter3.increment(), 1)

  def test_unsafe_hard_delete_autoproxywrapper(self):
    tag = 'test_unsafe_hard_delete_autoproxywrapper'
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    counter1 = shared1.acquire()
    counter2 = shared2.acquire()
    self.assertEqual(counter1.increment(), 1)
    self.assertEqual(counter2.increment(), 2)
    # Hard delete triggered from the proxy object itself.
    self._ignore_errors(lambda: counter2.singletonProxy_unsafe_hard_delete())
    for dead in (counter1, counter2):
      with self.assertRaises(Exception):
        dead.get()
    counter3 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True).acquire()
    self.assertEqual(counter3.increment(), 1)

  def test_unsafe_hard_delete_no_op(self):
    tag = 'test_unsafe_hard_delete_no_op'
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag=tag, always_proxy=True)
    counter1 = shared1.acquire()
    counter2 = shared2.acquire()
    self.assertEqual(counter1.increment(), 1)
    self.assertEqual(counter2.increment(), 2)
    # Deleting a different tag must leave this tag's object untouched.
    self._ignore_errors(
        lambda: multi_process_shared.MultiProcessShared(
            Counter, tag='no_tag_to_delete').unsafe_hard_delete())
    self.assertEqual(counter1.increment(), 3)
    self.assertEqual(counter2.increment(), 4)

  def test_release_always_proxy(self):
    self._check_release_semantics(
        'test_release_always_proxy', always_proxy=True)
class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
  """Tests for MultiProcessShared objects created with spawn_process=True."""

  # Tags whose per-tag files setUp removes from the temp dir before each
  # test. Keep this in sync with every tag used in this class.
  _TAGS_TO_CLEAN = (
      'basic',
      'main',
      'to_delete',
      'to_keep',
      'mix1',
      'mix2',
      'test_process_exit',
      'test_zombie_reap',
      'unrelated_tag',
      'thundering_herd_test',
  )
  # File-name suffixes appended to each tag when cleaning.
  _TAG_FILE_SUFFIXES = ('', '.address', '.address.error')

  def setUp(self):
    # Presumably these are address/state files left by an earlier run;
    # remove them so every test starts from a clean slate.
    tempdir = tempfile.gettempdir()
    for tag in self._TAGS_TO_CLEAN:
      for ext in self._TAG_FILE_SUFFIXES:
        try:
          os.remove(os.path.join(tempdir, tag + ext))
        except OSError:
          # Usually the file simply does not exist; nothing to clean.
          pass

  def tearDown(self):
    # Kill any still-running child process (e.g. spawned servers) so one
    # test's servers cannot leak into the next.
    for p in multiprocessing.active_children():
      if p.is_alive():
        try:
          p.kill()
          p.join(timeout=1.0)
        except Exception:
          pass

  def _find_server_process(self):
    """Return the first live child process, failing the test if none exists.

    Note that ``multiprocessing.active_children()`` joins (reaps) children
    that have already exited, as a side effect.
    """
    server_process = next(
        (
            p for p in multiprocessing.active_children()
            if p.pid != os.getpid() and p.is_alive()),
        None)
    self.assertIsNotNone(
        server_process, "Could not find spawned server process")
    return server_process

  def _assert_server_exited(self, server_process, timeout=5):
    """Wait up to ``timeout`` seconds and assert ``server_process`` exited."""
    server_process.join(timeout=timeout)
    self.assertFalse(
        server_process.is_alive(),
        f"Server process {server_process.pid} is still alive after hard delete")
    self.assertIsNotNone(
        server_process.exitcode, "Process has no exit code (did not exit)")

  @staticmethod
  def _pid_exists(pid):
    """Return True if ``pid`` still has a process-table entry.

    Signal 0 only performs an existence/permission check. It succeeds both
    for a running process and for an exited child that has not been reaped
    (a zombie). Unlike ``multiprocessing.active_children()``, it never reaps
    anything itself.
    """
    try:
      os.kill(pid, 0)
    except OSError:
      return False
    return True

  def test_call(self):
    shared = multi_process_shared.MultiProcessShared(
        Counter, tag='main', always_proxy=True, spawn_process=True).acquire()
    self.assertEqual(shared.get(), 0)
    self.assertEqual(shared.increment(), 1)
    self.assertEqual(shared.increment(10), 11)
    self.assertEqual(shared.increment(value=10), 21)
    self.assertEqual(shared.get(), 21)

  def test_unsafe_hard_delete_autoproxywrapper(self):
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
    counter3 = multi_process_shared.MultiProcessShared(
        Counter, tag='to_keep', always_proxy=True,
        spawn_process=True).acquire()
    counter1 = shared1.acquire()
    counter2 = shared2.acquire()
    self.assertEqual(counter1.increment(), 1)
    self.assertEqual(counter2.increment(), 2)
    try:
      counter2.singletonProxy_unsafe_hard_delete()
    except Exception:
      # Best-effort: only the state afterwards is asserted.
      pass
    with self.assertRaises(Exception):
      counter1.get()
    with self.assertRaises(Exception):
      counter2.get()
    # Re-acquiring the deleted tag starts from zero; 'to_keep' is separate.
    counter4 = multi_process_shared.MultiProcessShared(
        Counter, tag='to_delete', always_proxy=True,
        spawn_process=True).acquire()
    self.assertEqual(counter3.increment(), 1)
    self.assertEqual(counter4.increment(), 1)

  def test_mix_usage(self):
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire()
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire()
    self.assertEqual(shared1.get(), 0)
    self.assertEqual(shared1.increment(), 1)
    self.assertEqual(shared2.get(), 0)
    self.assertEqual(shared2.increment(), 1)

  def test_process_exits_on_unsafe_hard_delete(self):
    shared = multi_process_shared.MultiProcessShared(
        Counter, tag='test_process_exit', always_proxy=True, spawn_process=True)
    obj = shared.acquire()
    self.assertEqual(obj.increment(), 1)
    server_process = self._find_server_process()
    try:
      obj.singletonProxy_unsafe_hard_delete()
    except Exception:
      pass
    self._assert_server_exited(server_process)
    with self.assertRaises(Exception):
      obj.get()

  def test_process_exits_on_unsafe_hard_delete_with_manager(self):
    shared = multi_process_shared.MultiProcessShared(
        Counter, tag='test_process_exit', always_proxy=True, spawn_process=True)
    obj = shared.acquire()
    self.assertEqual(obj.increment(), 1)
    server_process = self._find_server_process()
    try:
      shared.unsafe_hard_delete()
    except Exception:
      pass
    self._assert_server_exited(server_process)
    with self.assertRaises(Exception):
      obj.get()

  def test_zombie_reaping_on_acquire(self):
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True)
    obj = shared1.acquire()
    server_pid = self._find_server_process().pid
    try:
      obj.singletonProxy_unsafe_hard_delete()
    except Exception:
      pass
    # Nothing has reaped the old server yet, so its pid must still exist
    # (still running, or exited and left as a zombie).
    self.assertTrue(
        self._pid_exists(server_pid),
        f"Server process {server_pid} was reaped too early before acquire()")
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
    _ = shared2.acquire()
    # Inspect the OS process table directly. Listing
    # multiprocessing.active_children() here would itself reap the old
    # server, so the check could never fail even if acquire() did no sweep.
    self.assertFalse(
        self._pid_exists(server_pid),
        f"Old server process {server_pid} was not reaped by acquire() sweep")
    try:
      shared2.unsafe_hard_delete()
    except Exception:
      pass
if __name__ == '__main__':
  # When run directly, lower the root logger threshold to INFO before
  # handing control to the unittest runner.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()