# blob: 36840877c54cdce71e13e9ba941c1a5c41fdd735 [file] [log] [blame]
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import casbin
from tests.test_enforcer import get_examples, TestCaseBase
from unittest import IsolatedAsyncioTestCase
class SampleWatcher:
    """Synchronous, in-memory watcher stub for the tests in this module.

    No I/O is performed.  Every ``update_for_*`` hook forwards a fixed
    message to :meth:`update`, which stores it in ``notify_message`` so a
    test can tell which hook ran last.
    """

    def __init__(self):
        # Most recent message passed to update(); None until a hook runs.
        self.notify_message = None
        # Callable registered through set_update_callback(); this stub only
        # stores it and never invokes it.
        self.callback = None

    def close(self):
        """Do nothing; the stub holds no resources."""
        pass

    def set_update_callback(self, callback):
        """
        Remember the callable to use when the policy is updated.

        :param callable callback: callback(event)
        :return:
        """
        self.callback = callback

    def update(self, msg):
        """
        Record ``msg`` as the latest notification.

        :param msg: message describing the change
        :return: always True
        """
        self.notify_message = msg
        return True

    def update_for_add_policy(self, section, ptype, *params):
        """
        Hook for a single added policy rule.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return self.update("called add policy")

    def update_for_remove_policy(self, section, ptype, *params):
        """
        Hook for a single removed policy rule.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return self.update("called remove policy")

    def update_for_remove_filtered_policy(self, section, ptype, field_index, *params):
        """
        Hook for a filtered removal.

        :param section: section
        :param ptype: policy type
        :param field_index: field index
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return self.update("called remove filtered policy")

    def update_for_save_policy(self, model: casbin.Model):
        """
        Hook for a saved policy.

        :param model: casbin model (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return self.update("called save policy")

    def update_for_add_policies(self, section, ptype, *params):
        """
        Hook for a batch of added policy rules.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return self.update("called add policies")

    def update_for_remove_policies(self, section, ptype, *params):
        """
        Hook for a batch of removed policy rules.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return self.update("called remove policies")

    def start_watch(self):
        """
        Do nothing; this stub has nothing to start.

        :return:
        """
        pass
class AsyncSampleWatcher:
    """Coroutine-based, in-memory watcher stub for the async tests.

    No I/O is performed.  Every ``update_for_*`` coroutine awaits
    :meth:`update` with a fixed message, which is stored in
    ``notify_message`` so a test can tell which hook ran last.
    """

    def __init__(self):
        # Most recent message passed to update(); None until a hook runs.
        self.notify_message = None
        # Callable registered through set_update_callback(); this stub only
        # stores it and never invokes it.
        self.callback = None

    async def close(self):
        """Do nothing; the stub holds no resources."""
        pass

    async def set_update_callback(self, callback):
        """
        Remember the callable to use when the policy is updated.

        :param callable callback: callback(event)
        :return:
        """
        self.callback = callback

    async def update(self, msg):
        """
        Record ``msg`` as the latest notification.

        :param msg: message describing the change
        :return: always True
        """
        self.notify_message = msg
        return True

    async def update_for_add_policy(self, section, ptype, *params):
        """
        Hook for a single added policy rule.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return await self.update("called add policy")

    async def update_for_remove_policy(self, section, ptype, *params):
        """
        Hook for a single removed policy rule.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return await self.update("called remove policy")

    async def update_for_remove_filtered_policy(self, section, ptype, field_index, *params):
        """
        Hook for a filtered removal.

        :param section: section
        :param ptype: policy type
        :param field_index: field index
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return await self.update("called remove filtered policy")

    async def update_for_save_policy(self, model: casbin.Model):
        """
        Hook for a saved policy.

        :param model: casbin model (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return await self.update("called save policy")

    async def update_for_add_policies(self, section, ptype, *params):
        """
        Hook for a batch of added policy rules.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return await self.update("called add policies")

    async def update_for_remove_policies(self, section, ptype, *params):
        """
        Hook for a batch of removed policy rules.

        :param section: section
        :param ptype: policy type
        :param params: other params (ignored by this stub)
        :return: result of ``update`` (True)
        """
        return await self.update("called remove policies")

    async def start_watch(self):
        """
        Do nothing; this stub has nothing to start.

        :return:
        """
        pass
class TestWatcherEx(TestCaseBase):
    """Checks which ``SampleWatcher`` hooks a ``casbin.Enforcer`` triggers,
    with auto-notify switched on and off."""

    def get_enforcer(self, model=None, adapter=None):
        """Build a ``casbin.Enforcer`` from the given model and adapter arguments."""
        return casbin.Enforcer(model, adapter)

    def test_auto_notify_enabled(self):
        # With auto-notify on, each mutating call is expected to leave its
        # own hook message on the watcher.
        enforcer = self.get_enforcer(
            get_examples("basic_model.conf"),
            get_examples("basic_policy.csv"),
        )
        watcher = SampleWatcher()
        enforcer.set_watcher(watcher)
        enforcer.enable_auto_notify_watcher(True)

        enforcer.save_policy()
        self.assertEqual(watcher.notify_message, "called save policy")

        enforcer.add_policy("admin", "data1", "read")
        self.assertEqual(watcher.notify_message, "called add policy")

        enforcer.remove_policy("admin", "data1", "read")
        self.assertEqual(watcher.notify_message, "called remove policy")

        enforcer.remove_filtered_policy(1, "data1")
        self.assertEqual(watcher.notify_message, "called remove filtered policy")

        batch = [
            ["jack", "data4", "read"],
            ["katy", "data4", "write"],
            ["leyo", "data4", "read"],
            ["ham", "data4", "write"],
        ]

        enforcer.add_policies(batch)
        self.assertEqual(watcher.notify_message, "called add policies")

        enforcer.remove_policies(batch)
        self.assertEqual(watcher.notify_message, "called remove policies")

    def test_auto_notify_disabled(self):
        # With auto-notify off, this test still expects save_policy() to
        # notify; every other mutation must leave the watcher untouched.
        enforcer = self.get_enforcer(
            get_examples("basic_model.conf"),
            get_examples("basic_policy.csv"),
        )
        watcher = SampleWatcher()
        enforcer.set_watcher(watcher)
        enforcer.enable_auto_notify_watcher(False)

        enforcer.save_policy()
        self.assertEqual(watcher.notify_message, "called save policy")

        # Clear the recorded message so any later notification is visible.
        watcher.notify_message = None

        enforcer.add_policy("admin", "data1", "read")
        self.assertIsNone(watcher.notify_message)

        enforcer.remove_policy("admin", "data1", "read")
        self.assertIsNone(watcher.notify_message)

        enforcer.remove_filtered_policy(1, "data1")
        self.assertIsNone(watcher.notify_message)

        batch = [
            ["jack", "data4", "read"],
            ["katy", "data4", "write"],
            ["leyo", "data4", "read"],
            ["ham", "data4", "write"],
        ]

        enforcer.add_policies(batch)
        self.assertIsNone(watcher.notify_message)

        enforcer.remove_policies(batch)
        self.assertIsNone(watcher.notify_message)
class TestAsyncWatcherEx(IsolatedAsyncioTestCase):
    """Checks which watcher hooks a ``casbin.AsyncEnforcer`` triggers,
    with auto-notify switched on and off."""

    def get_enforcer(self, model=None, adapter=None):
        """Build a ``casbin.AsyncEnforcer`` from the given model and adapter arguments."""
        return casbin.AsyncEnforcer(model, adapter)

    async def test_auto_notify_enabled(self):
        # With auto-notify on, each awaited mutation is expected to leave
        # its own hook message on the watcher.
        enforcer = self.get_enforcer(
            get_examples("basic_model.conf"),
            get_examples("basic_policy.csv"),
        )
        await enforcer.load_policy()

        watcher = AsyncSampleWatcher()
        enforcer.set_watcher(watcher)
        enforcer.enable_auto_notify_watcher(True)

        await enforcer.save_policy()
        self.assertEqual(watcher.notify_message, "called save policy")

        await enforcer.add_policy("admin", "data1", "read")
        self.assertEqual(watcher.notify_message, "called add policy")

        await enforcer.remove_policy("admin", "data1", "read")
        self.assertEqual(watcher.notify_message, "called remove policy")

        await enforcer.remove_filtered_policy(1, "data1")
        self.assertEqual(watcher.notify_message, "called remove filtered policy")

        batch = [
            ["jack", "data4", "read"],
            ["katy", "data4", "write"],
            ["leyo", "data4", "read"],
            ["ham", "data4", "write"],
        ]

        await enforcer.add_policies(batch)
        self.assertEqual(watcher.notify_message, "called add policies")

        await enforcer.remove_policies(batch)
        self.assertEqual(watcher.notify_message, "called remove policies")

    async def test_auto_notify_disabled(self):
        # With auto-notify off, this test still expects save_policy() to
        # notify; every other mutation must leave the watcher untouched.
        enforcer = self.get_enforcer(
            get_examples("basic_model.conf"),
            get_examples("basic_policy.csv"),
        )
        await enforcer.load_policy()

        # NOTE(review): this attaches the synchronous SampleWatcher to the
        # async enforcer, unlike test_auto_notify_enabled which uses
        # AsyncSampleWatcher -- confirm whether that is intentional.
        watcher = SampleWatcher()
        enforcer.set_watcher(watcher)
        enforcer.enable_auto_notify_watcher(False)

        await enforcer.save_policy()
        self.assertEqual(watcher.notify_message, "called save policy")

        # Clear the recorded message so any later notification is visible.
        watcher.notify_message = None

        await enforcer.add_policy("admin", "data1", "read")
        self.assertIsNone(watcher.notify_message)

        await enforcer.remove_policy("admin", "data1", "read")
        self.assertIsNone(watcher.notify_message)

        await enforcer.remove_filtered_policy(1, "data1")
        self.assertIsNone(watcher.notify_message)

        batch = [
            ["jack", "data4", "read"],
            ["katy", "data4", "write"],
            ["leyo", "data4", "read"],
            ["ham", "data4", "write"],
        ]

        await enforcer.add_policies(batch)
        self.assertIsNone(watcher.notify_message)

        await enforcer.remove_policies(batch)
        self.assertIsNone(watcher.notify_message)