| # Copyright 2023 The casbin Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import os |
| import unittest |
| from unittest import IsolatedAsyncioTestCase |
| |
| import casbin |
| from sqlalchemy import Column, Integer, String, select, func |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker |
| |
| from casbin_async_sqlalchemy_adapter import Adapter |
| from casbin_async_sqlalchemy_adapter import Base |
| from casbin_async_sqlalchemy_adapter import CasbinRule |
| from casbin_async_sqlalchemy_adapter.adapter import Filter |
| |
| |
def get_fixture(path):
    """Return the absolute path of ``path`` resolved against this file's directory."""
    here = os.path.dirname(os.path.realpath(__file__))
    return os.path.abspath(here + "/" + path)
| |
| |
async def get_enforcer():
    """Build an AsyncEnforcer backed by a fresh in-memory SQLite database.

    The rule table is created through the adapter, seeded with four ``p``
    rules and one ``g`` rule, and the policy is loaded before the enforcer
    is returned.
    """
    # For debugging against a file-backed database, a URL of the form
    # 'sqlite+aiosqlite:///' + <directory of this file> + '/test.db'
    # (optionally with echo=True) can be used instead of the in-memory one.
    engine = create_async_engine("sqlite+aiosqlite://", future=True)
    adapter = Adapter(engine)
    await adapter.create_table()

    seed_rules = [
        CasbinRule(ptype="p", v0="alice", v1="data1", v2="read"),
        CasbinRule(ptype="p", v0="bob", v1="data2", v2="write"),
        CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="read"),
        CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="write"),
        CasbinRule(ptype="g", v0="alice", v1="data2_admin"),
    ]

    session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
    async with session_factory() as session:
        session.add_all(seed_rules)
        await session.commit()

    enforcer = casbin.AsyncEnforcer(get_fixture("rbac_model.conf"), adapter)
    await enforcer.load_policy()
    return enforcer
| |
| |
class TestConfig(IsolatedAsyncioTestCase):
    """Tests for the async SQLAlchemy adapter, mostly driven through an AsyncEnforcer.

    Most tests start from ``get_enforcer()``, which seeds an in-memory
    database with four ``p`` rules and one ``g`` rule.
    """

    async def test_custom_db_class(self):
        """A declarative rule class with an extra column can be stored and read back."""

        class CustomRule(Base):
            __tablename__ = "casbin_rule2"

            id = Column(Integer, primary_key=True)
            ptype = Column(String(255))
            v0 = Column(String(255))
            v1 = Column(String(255))
            v2 = Column(String(255))
            v3 = Column(String(255))
            v4 = Column(String(255))
            v5 = Column(String(255))
            not_exist = Column(String(255))

        # NOTE(review): CustomRule is never handed to Adapter in this test, so
        # only the ORM mapping itself is exercised -- confirm that is intended.
        engine = create_async_engine("sqlite+aiosqlite://", future=True)
        # Release the engine's connection even when an assertion below fails.
        self.addAsyncCleanup(engine.dispose)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
        async with session() as s:
            s.add(CustomRule(not_exist="NotNone"))
            await s.commit()
            a = await s.execute(select(CustomRule))
            self.assertEqual(a.scalars().all()[0].not_exist, "NotNone")

    async def test_enforcer_basic(self):
        """The seeded policy yields the expected allow/deny decisions."""
        e = await get_enforcer()
        self.assertTrue(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        # alice has no direct data2 rule; presumably she is allowed through the
        # seeded ``g`` rule linking her to data2_admin.
        self.assertTrue(e.enforce("alice", "data2", "read"))
        self.assertTrue(e.enforce("alice", "data2", "write"))

    async def test_add_policy(self):
        """Adding a single permission makes the matching request pass."""
        e = await get_enforcer()

        self.assertFalse(e.enforce("eve", "data3", "read"))
        res = await e.add_permission_for_user("eve", "data3", "read")
        self.assertTrue(res)
        self.assertTrue(e.enforce("eve", "data3", "read"))

    async def test_add_policies(self):
        """Adding several rules in one call makes every matching request pass."""
        e = await get_enforcer()

        self.assertFalse(e.enforce("eve", "data3", "read"))
        self.assertFalse(e.enforce("eve", "data4", "read"))
        res = await e.add_policies((("eve", "data3", "read"), ("eve", "data4", "read")))
        self.assertTrue(res)
        self.assertTrue(e.enforce("eve", "data3", "read"))
        self.assertTrue(e.enforce("eve", "data4", "read"))

    async def test_save_policy(self):
        """A model edited in place can be written out through the adapter."""
        e = await get_enforcer()
        self.assertFalse(e.enforce("alice", "data4", "read"))

        model = e.get_model()
        model.clear_policy()

        model.add_policy("p", "p", ["alice", "data4", "read"])

        adapter = e.get_adapter()
        await adapter.save_policy(model)
        # NOTE(review): the rule was added to the enforcer's own model above, so
        # this assertion appears to hold even if save_policy() stored nothing;
        # consider reloading from the adapter first to verify persistence.
        self.assertTrue(e.enforce("alice", "data4", "read"))

    async def test_remove_policy(self):
        """A permission that was added can be removed again."""
        e = await get_enforcer()

        self.assertFalse(e.enforce("alice", "data5", "read"))
        await e.add_permission_for_user("alice", "data5", "read")
        self.assertTrue(e.enforce("alice", "data5", "read"))
        await e.delete_permission_for_user("alice", "data5", "read")
        self.assertFalse(e.enforce("alice", "data5", "read"))

    async def test_remove_policies(self):
        """Several rules added in one call can be removed in one call."""
        e = await get_enforcer()

        self.assertFalse(e.enforce("alice", "data5", "read"))
        self.assertFalse(e.enforce("alice", "data6", "read"))
        await e.add_policies((("alice", "data5", "read"), ("alice", "data6", "read")))
        self.assertTrue(e.enforce("alice", "data5", "read"))
        self.assertTrue(e.enforce("alice", "data6", "read"))
        await e.remove_policies((("alice", "data5", "read"), ("alice", "data6", "read")))
        self.assertFalse(e.enforce("alice", "data5", "read"))
        self.assertFalse(e.enforce("alice", "data6", "read"))

    async def test_remove_filtered_policy(self):
        """Rules are removed by field index and value(s), leaving the rest intact."""
        e = await get_enforcer()

        self.assertTrue(e.enforce("alice", "data1", "read"))
        await e.remove_filtered_policy(1, "data1")
        self.assertFalse(e.enforce("alice", "data1", "read"))

        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertTrue(e.enforce("alice", "data2", "read"))
        self.assertTrue(e.enforce("alice", "data2", "write"))

        await e.remove_filtered_policy(1, "data2", "read")

        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertTrue(e.enforce("alice", "data2", "write"))

        await e.remove_filtered_policy(2, "write")

        self.assertFalse(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("alice", "data2", "write"))

    async def test_str(self):
        """``str(CasbinRule)`` joins ptype and the populated values with ", "."""
        rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")
        self.assertEqual(str(rule), "p, alice, data1, read")
        rule = CasbinRule(ptype="p", v0="bob", v1="data2", v2="write")
        self.assertEqual(str(rule), "p, bob, data2, write")
        rule = CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="read")
        self.assertEqual(str(rule), "p, data2_admin, data2, read")
        rule = CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="write")
        self.assertEqual(str(rule), "p, data2_admin, data2, write")
        rule = CasbinRule(ptype="g", v0="alice", v1="data2_admin")
        self.assertEqual(str(rule), "g, alice, data2_admin")

    async def test_repr(self):
        """``repr(CasbinRule)`` shows ``None`` before persisting and a numeric id afterwards."""
        rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")
        self.assertEqual(repr(rule), '<CasbinRule None: "p, alice, data1, read">')

        engine = create_async_engine("sqlite+aiosqlite://", future=True)
        # Release the engine's connection even when an assertion below fails.
        self.addAsyncCleanup(engine.dispose)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
        # The context manager closes the session even if the assertion raises.
        async with session() as s:
            s.add(rule)
            await s.commit()
            self.assertRegex(repr(rule), r'<CasbinRule \d+: "p, alice, data1, read">')

    async def test_filtered_policy(self):
        """``load_filtered_policy`` loads only the rules matching the Filter fields."""
        e = await get_enforcer()
        rule_filter = Filter()

        # Only ``p`` rules: the ``g`` rule is not loaded, so alice loses data2.
        rule_filter.ptype = ["p"]
        await e.load_filtered_policy(rule_filter)
        self.assertTrue(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertTrue(e.enforce("bob", "data2", "write"))

        # Filter on v0 == alice.
        rule_filter.ptype = []
        rule_filter.v0 = ["alice"]
        await e.load_filtered_policy(rule_filter)
        self.assertTrue(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertFalse(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("data2_admin", "data2", "read"))
        self.assertFalse(e.enforce("data2_admin", "data2", "write"))

        # Filter on v0 == bob.
        rule_filter.v0 = ["bob"]
        await e.load_filtered_policy(rule_filter)
        self.assertFalse(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("data2_admin", "data2", "read"))
        self.assertFalse(e.enforce("data2_admin", "data2", "write"))

        # Filter on v0 == data2_admin: both of its seeded rules are loaded.
        rule_filter.v0 = ["data2_admin"]
        await e.load_filtered_policy(rule_filter)
        self.assertTrue(e.enforce("data2_admin", "data2", "read"))
        self.assertTrue(e.enforce("data2_admin", "data2", "write"))
        self.assertFalse(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertFalse(e.enforce("bob", "data2", "write"))

        # Several values for one field.
        rule_filter.v0 = ["alice", "bob"]
        await e.load_filtered_policy(rule_filter)
        self.assertTrue(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("data2_admin", "data2", "read"))
        self.assertFalse(e.enforce("data2_admin", "data2", "write"))

        # Filter on v1 == data1.
        rule_filter.v0 = []
        rule_filter.v1 = ["data1"]
        await e.load_filtered_policy(rule_filter)
        self.assertTrue(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertFalse(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("data2_admin", "data2", "read"))
        self.assertFalse(e.enforce("data2_admin", "data2", "write"))

        # Filter on v1 == data2.
        rule_filter.v1 = ["data2"]
        await e.load_filtered_policy(rule_filter)
        self.assertFalse(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertTrue(e.enforce("data2_admin", "data2", "read"))
        self.assertTrue(e.enforce("data2_admin", "data2", "write"))

        # Filter on v2 == read.
        rule_filter.v1 = []
        rule_filter.v2 = ["read"]
        await e.load_filtered_policy(rule_filter)
        self.assertTrue(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertFalse(e.enforce("bob", "data2", "write"))
        self.assertTrue(e.enforce("data2_admin", "data2", "read"))
        self.assertFalse(e.enforce("data2_admin", "data2", "write"))

        # Filter on v2 == write.
        rule_filter.v2 = ["write"]
        await e.load_filtered_policy(rule_filter)
        self.assertFalse(e.enforce("alice", "data1", "read"))
        self.assertFalse(e.enforce("alice", "data1", "write"))
        self.assertFalse(e.enforce("alice", "data2", "read"))
        self.assertFalse(e.enforce("alice", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data1", "read"))
        self.assertFalse(e.enforce("bob", "data1", "write"))
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertTrue(e.enforce("bob", "data2", "write"))
        self.assertFalse(e.enforce("data2_admin", "data2", "read"))
        self.assertTrue(e.enforce("data2_admin", "data2", "write"))

    async def test_update_policy(self):
        """Updating a rule revokes the old permission and grants the new one."""
        e = await get_enforcer()
        example_p = ["mike", "cookie", "eat"]

        self.assertTrue(e.enforce("alice", "data1", "read"))
        await e.update_policy(["alice", "data1", "read"], ["alice", "data1", "no_read"])
        self.assertFalse(e.enforce("alice", "data1", "read"))

        self.assertFalse(e.enforce("bob", "data1", "read"))
        await e.add_policy(example_p)
        await e.update_policy(example_p, ["bob", "data1", "read"])
        self.assertTrue(e.enforce("bob", "data1", "read"))

        self.assertFalse(e.enforce("bob", "data1", "write"))
        await e.update_policy(["bob", "data1", "read"], ["bob", "data1", "write"])
        self.assertTrue(e.enforce("bob", "data1", "write"))

        self.assertTrue(e.enforce("bob", "data2", "write"))
        await e.update_policy(["bob", "data2", "write"], ["bob", "data2", "read"])
        self.assertFalse(e.enforce("bob", "data2", "write"))

        self.assertTrue(e.enforce("bob", "data2", "read"))
        await e.update_policy(["bob", "data2", "read"], ["carl", "data2", "write"])
        # The rule that was just replaced must no longer grant access.
        self.assertFalse(e.enforce("bob", "data2", "read"))
        self.assertFalse(e.enforce("bob", "data2", "write"))

        self.assertTrue(e.enforce("carl", "data2", "write"))
        await e.update_policy(["carl", "data2", "write"], ["carl", "data2", "no_write"])
        # The rule that was just replaced must no longer grant access.
        self.assertFalse(e.enforce("carl", "data2", "write"))
        self.assertFalse(e.enforce("bob", "data2", "write"))

    async def test_update_policies(self):
        """A batch update replaces each old rule with its new counterpart."""
        e = await get_enforcer()

        old_rule_0 = ["alice", "data1", "read"]
        old_rule_1 = ["bob", "data2", "write"]
        old_rule_2 = ["data2_admin", "data2", "read"]
        old_rule_3 = ["data2_admin", "data2", "write"]

        new_rule_0 = ["alice", "data_test", "read"]
        new_rule_1 = ["bob", "data_test", "write"]
        new_rule_2 = ["data2_admin", "data_test", "read"]
        new_rule_3 = ["data2_admin", "data_test", "write"]

        old_rules = [old_rule_0, old_rule_1, old_rule_2, old_rule_3]
        new_rules = [new_rule_0, new_rule_1, new_rule_2, new_rule_3]

        await e.update_policies(old_rules, new_rules)

        self.assertFalse(e.enforce("alice", "data1", "read"))
        self.assertTrue(e.enforce("alice", "data_test", "read"))

        self.assertFalse(e.enforce("bob", "data2", "write"))
        self.assertTrue(e.enforce("bob", "data_test", "write"))

        self.assertFalse(e.enforce("data2_admin", "data2", "read"))
        self.assertTrue(e.enforce("data2_admin", "data_test", "read"))

        self.assertFalse(e.enforce("data2_admin", "data2", "write"))
        self.assertTrue(e.enforce("data2_admin", "data_test", "write"))

    async def test_update_filtered_policies(self):
        """New rules passed to ``update_filtered_policies`` are enforced afterwards."""
        e = await get_enforcer()

        await e.update_filtered_policies(
            [
                ["data2_admin", "data3", "read"],
                ["data2_admin", "data3", "write"],
            ],
            0,
            "data2_admin",
        )
        self.assertTrue(e.enforce("data2_admin", "data3", "write"))
        self.assertTrue(e.enforce("data2_admin", "data3", "read"))

        await e.update_filtered_policies([["alice", "data1", "write"]], 0, "alice")
        self.assertTrue(e.enforce("alice", "data1", "write"))

        await e.update_filtered_policies([["bob", "data2", "read"]], 0, "bob")
        self.assertTrue(e.enforce("bob", "data2", "read"))

    async def test_clear_policy(self):
        """Test that clear_policy() removes all records from the database."""
        e = await get_enforcer()
        adapter = e.get_adapter()
        engine = adapter._engine

        # Verify there are policies in the database
        async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
        async with async_session() as s:
            cnt = await s.execute(select(func.count()).select_from(CasbinRule))
            initial_count = cnt.scalar_one()
            self.assertGreater(initial_count, 0, "There should be policies in the database before clearing")

        # Clear all policies from the database
        await adapter.clear_policy()

        # Verify all policies are removed from the database
        async with async_session() as s:
            cnt = await s.execute(select(func.count()).select_from(CasbinRule))
            final_count = cnt.scalar_one()
            self.assertEqual(final_count, 0, "All policies should be removed from the database")

        # Verify enforcer still works after clearing (can load empty policy)
        await e.load_policy()
        self.assertFalse(e.enforce("alice", "data1", "read"))

        # Verify we can add policies after clearing
        await e.add_policy("eve", "data3", "read")
        self.assertTrue(e.enforce("eve", "data3", "read"))
| |
| |
class TestBulkInsert(IsolatedAsyncioTestCase):
    """Checks ``Adapter.add_policies`` by inspecting the database directly."""

    async def test_add_policies_bulk_internal_session(self):
        """Every rule passed to ``add_policies`` ends up stored as a ``p`` row."""
        engine = create_async_engine("sqlite+aiosqlite://", future=True)
        # Release the engine's connection even when an assertion below fails.
        self.addAsyncCleanup(engine.dispose)
        adapter = Adapter(engine)
        await adapter.create_table()

        rules = [
            ("u1", "obj1", "read"),
            ("u2", "obj2", "write"),
            ("u3", "obj3", "read"),
        ]
        await adapter.add_policies("p", "p", rules)

        async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
        async with async_session() as s:
            # Count the inserted rows. unittest assertions are used instead of
            # bare ``assert`` so the checks survive ``python -O``.
            cnt = await s.execute(select(func.count()).select_from(CasbinRule).where(CasbinRule.ptype == "p"))
            self.assertEqual(cnt.scalar_one(), len(rules))

            rows = (await s.execute(select(CasbinRule).order_by(CasbinRule.id))).scalars().all()
            stored = [(r.v0, r.v1, r.v2) for r in rows]
            for rule in rules:
                self.assertIn(rule, stored)
| |
| |
| # Run every test in this module when the file is executed directly as a script. |
| if __name__ == "__main__": |
| unittest.main() |