blob: 5311ccb2ce9fa45d1e0e7d5add77413174b31f43 [file] [log] [blame]
import os
import sys
from logging import DEBUG, getLogger, StreamHandler
from multiprocessing import connection, context
from time import sleep
from unittest import TestCase, main
from unittest.mock import MagicMock

from postgresql_watcher import PostgresqlWatcher
from postgresql_watcher.casbin_channel_subscription import CASBIN_CHANNEL_SELECT_TIMEOUT
# Connection settings for the PostgreSQL server these tests run against.
# WARNING: the tests need a live PostgreSQL instance. Point them at your own
# server by setting the PG_WATCHER_TEST_* environment variables; the literal
# values below are only fallbacks for a default local installation.
HOST = os.environ.get("PG_WATCHER_TEST_HOST", "127.0.0.1")
PORT = int(os.environ.get("PG_WATCHER_TEST_PORT", "5432"))
USER = os.environ.get("PG_WATCHER_TEST_USER", "postgres")
PASSWORD = os.environ.get("PG_WATCHER_TEST_PASSWORD", "123456")
DBNAME = os.environ.get("PG_WATCHER_TEST_DBNAME", "postgres")

# Root logger at DEBUG level, echoing every record to stdout so log output
# appears alongside the test runner's output.
logger = getLogger()
# setLevel (rather than assigning .level) also invalidates the logging
# module's cached isEnabledFor() results.
logger.setLevel(DEBUG)
stream_handler = StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
def get_watcher(channel_name):
    """Return a PostgresqlWatcher listening on *channel_name*.

    All watchers built here share the module-level connection settings and
    logger; only the channel name varies between callers.
    """
    db_settings = dict(
        host=HOST,
        port=PORT,
        user=USER,
        password=PASSWORD,
        dbname=DBNAME,
    )
    return PostgresqlWatcher(logger=logger, channel_name=channel_name, **db_settings)
# _winapi is a Windows-only module. On Windows it must import successfully;
# on every other platform a failed import leaves _winapi as None so tests can
# branch on it.
try:
    import _winapi
    from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE
except ImportError:
    if sys.platform != "win32":
        _winapi = None
    else:
        # A missing _winapi on Windows is a real error: propagate it.
        raise
class TestConfig(TestCase):
    """Integration tests for PostgresqlWatcher against a live PostgreSQL server.

    The tests are timing based: after acting on a watcher they sleep a
    multiple of CASBIN_CHANNEL_SELECT_TIMEOUT before checking should_reload().
    This gives the watcher's subscription process time to observe the channel.
    """

    def _watcher(self, channel_name):
        """Create a watcher with get_watcher() and stop it when the test ends.

        Registering stop() as a cleanup keeps each test's subscription
        process(es) from outliving the test. It runs even if the test fails.
        """
        watcher = get_watcher(channel_name)
        self.addCleanup(watcher.stop)
        return watcher

    @staticmethod
    def _wait_for_delivery(periods=2):
        """Sleep for *periods* select timeouts so the subscriber can catch up."""
        sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * periods)

    def test_pg_watcher_init(self):
        pg_watcher = self._watcher("test_pg_watcher_init")
        # multiprocessing pipes are PipeConnection objects on Windows and plain
        # Connection objects elsewhere (PipeConnection only exists on Windows).
        expected_conn_type = (
            connection.PipeConnection if _winapi else connection.Connection
        )
        self.assertIsInstance(pg_watcher.parent_conn, expected_conn_type)
        self.assertIsInstance(pg_watcher.subscription_process, context.Process)

    def test_update_single_pg_watcher(self):
        pg_watcher = self._watcher("test_update_single_pg_watcher")
        pg_watcher.update()
        self._wait_for_delivery()
        self.assertTrue(pg_watcher.should_reload())

    def test_no_update_single_pg_watcher(self):
        pg_watcher = self._watcher("test_no_update_single_pg_watcher")
        self._wait_for_delivery()
        self.assertFalse(pg_watcher.should_reload())

    def test_update_mutiple_pg_watcher(self):
        channel_name = "test_update_mutiple_pg_watcher"
        main_watcher = self._watcher(channel_name)
        other_watchers = [self._watcher(channel_name) for _ in range(5)]
        main_watcher.update()
        self._wait_for_delivery()
        # Every other watcher on the same channel must see the update.
        for watcher in other_watchers:
            self.assertTrue(watcher.should_reload())

    def test_no_update_mutiple_pg_watcher(self):
        channel_name = "test_no_update_mutiple_pg_watcher"
        main_watcher = self._watcher(channel_name)
        other_watchers = [self._watcher(channel_name) for _ in range(5)]
        self._wait_for_delivery()
        for watcher in other_watchers:
            self.assertFalse(watcher.should_reload())
        self.assertFalse(main_watcher.should_reload())

    def test_update_handler_called(self):
        channel_name = "test_update_handler_called"
        main_watcher = self._watcher(channel_name)
        handler = MagicMock()
        main_watcher.set_update_callback(handler)
        main_watcher.update()
        self._wait_for_delivery()
        self.assertTrue(main_watcher.should_reload())
        self.assertEqual(handler.call_count, 1)

    def test_update_handler_called_multiple_channel_messages(self):
        channel_name = "test_update_handler_called_multiple_channel_messages"
        main_watcher = self._watcher(channel_name)
        handler = MagicMock()
        main_watcher.set_update_callback(handler)
        number_of_updates = 5
        for _ in range(number_of_updates):
            main_watcher.update()
        self._wait_for_delivery(number_of_updates + 1)
        # Poll until should_reload() reports nothing pending. The burst of
        # updates must still have triggered the callback exactly once.
        while main_watcher.should_reload():
            pass
        self.assertEqual(handler.call_count, 1)

    def test_update_handler_not_called(self):
        channel_name = "test_update_handler_not_called"
        main_watcher = self._watcher(channel_name)
        handler = MagicMock()
        main_watcher.set_update_callback(handler)
        self._wait_for_delivery()
        self.assertFalse(main_watcher.should_reload())
        self.assertEqual(handler.call_count, 0)

    def test_stop_and_restart(self):
        channel_name = "test_stop_and_restart"
        pg_watcher = self._watcher(channel_name)
        # Verify initially started
        self.assertTrue(pg_watcher.subscription_process.is_alive())
        # Stop the watcher
        pg_watcher.stop()
        self.assertIsNone(pg_watcher.subscription_process)
        # Restart the watcher; the registered cleanup stops it again at the end.
        pg_watcher.start()
        # Verify resources are recreated and process is alive
        self.assertTrue(pg_watcher.subscription_process.is_alive())
        # Verify it still works after restart
        pg_watcher.update()
        self._wait_for_delivery()
        self.assertTrue(pg_watcher.should_reload())
if __name__ == "__main__":
    # Executed as a script: hand off to unittest's command-line runner, which
    # discovers and runs the TestCase classes in this module.
    main()