| from contextlib import contextmanager |
| |
| import sqlalchemy |
| from casbin import persist |
| from sqlalchemy import Column, Integer, String |
| from sqlalchemy import create_engine, or_ |
| from sqlalchemy.orm import sessionmaker |
| |
| # declarative base class |
| if sqlalchemy.__version__.startswith("1."): |
| from sqlalchemy.orm import declarative_base |
| |
| Base = declarative_base() |
| |
| else: |
| from sqlalchemy.orm import DeclarativeBase |
| |
| class Base(DeclarativeBase): |
| pass |
| |
| |
| # Cache for CasbinRule classes by table name to avoid duplicate class warnings |
| _casbin_rule_cache = {} |
| |
| |
| def create_casbin_rule_class(table_name): |
| """ |
| Factory function to create a CasbinRule class with a custom table name. |
| |
| Args: |
| table_name (str): Table name for the CasbinRule class. |
| |
| Returns: |
| db_class (CasbinRule): The CasbinRule class. |
| """ |
| # Return cached class if it exists for this table name |
| if table_name in _casbin_rule_cache: |
| return _casbin_rule_cache[table_name] |
| |
| # Create a unique class name based on the table name to avoid SQLAlchemy warnings |
| # Convert table_name to a valid Python class name |
| class_name = "CasbinRule_" + "".join(c if c.isalnum() else "_" for c in table_name) |
| |
| # Dynamically create the class with a unique name |
| CasbinRule = type( |
| class_name, |
| (Base,), |
| { |
| "__tablename__": table_name, |
| "__table_args__": {"extend_existing": True}, |
| "id": Column(Integer, primary_key=True), |
| "ptype": Column(String(255)), |
| "v0": Column(String(255)), |
| "v1": Column(String(255)), |
| "v2": Column(String(255)), |
| "v3": Column(String(255)), |
| "v4": Column(String(255)), |
| "v5": Column(String(255)), |
| "__str__": lambda self: ", ".join( |
| [self.ptype] |
| + [ |
| v |
| for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5) |
| if v is not None |
| ] |
| ), |
| "__repr__": lambda self: '<CasbinRule {}: "{}">'.format(self.id, str(self)), |
| "__module__": "sqlalchemy_adapter.adapter", |
| }, |
| ) |
| |
| # Cache the class before returning |
| _casbin_rule_cache[table_name] = CasbinRule |
| return CasbinRule |
| |
| |
| # Export the default CasbinRule class with table name 'casbin_rule'. |
| CasbinRule = create_casbin_rule_class("casbin_rule") |
| |
| |
| class Filter: |
| ptype = [] |
| v0 = [] |
| v1 = [] |
| v2 = [] |
| v3 = [] |
| v4 = [] |
| v5 = [] |
| |
| |
| class Adapter(persist.Adapter, persist.adapters.UpdateAdapter): |
| """the interface for Casbin adapters.""" |
| |
| def __init__( |
| self, |
| engine, |
| db_class=None, |
| table_name="casbin_rule", |
| filtered=False, |
| create_table=True, |
| ): |
| if isinstance(engine, str): |
| self._engine = create_engine(engine) |
| else: |
| self._engine = engine |
| |
| if db_class is None: |
| db_class = create_casbin_rule_class(table_name=table_name) |
| metadata = Base.metadata |
| else: |
| for attr in ( |
| "id", |
| "ptype", |
| "v0", |
| "v1", |
| "v2", |
| "v3", |
| "v4", |
| "v5", |
| ): # id attr was used by filter |
| if not hasattr(db_class, attr): |
| raise Exception(f"{attr} not found in custom DatabaseClass.") |
| metadata = db_class.metadata |
| |
| self._db_class = db_class |
| self.session_local = sessionmaker(bind=self._engine) |
| |
| if create_table: |
| metadata.create_all(self._engine) |
| self._filtered = filtered |
| |
| @contextmanager |
| def _session_scope(self): |
| """Provide a transactional scope around a series of operations.""" |
| session = self.session_local() |
| try: |
| yield session |
| session.commit() |
| except Exception as e: |
| session.rollback() |
| raise e |
| finally: |
| session.close() |
| |
| def load_policy(self, model): |
| """loads all policy rules from the storage.""" |
| with self._session_scope() as session: |
| lines = session.query(self._db_class).all() |
| for line in lines: |
| persist.load_policy_line(str(line), model) |
| |
| def is_filtered(self): |
| return self._filtered |
| |
| def load_filtered_policy(self, model, filter) -> None: |
| """loads all policy rules from the storage.""" |
| with self._session_scope() as session: |
| query = session.query(self._db_class) |
| filters = self.filter_query(query, filter) |
| filters = filters.all() |
| |
| for line in filters: |
| persist.load_policy_line(str(line), model) |
| self._filtered = True |
| |
| def filter_query(self, querydb, filter): |
| for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"): |
| if len(getattr(filter, attr)) > 0: |
| querydb = querydb.filter( |
| getattr(self._db_class, attr).in_(getattr(filter, attr)) |
| ) |
| return querydb.order_by(self._db_class.id) |
| |
| def _save_policy_line(self, ptype, rule, session=None): |
| line = self._db_class(ptype=ptype) |
| for i, v in enumerate(rule): |
| setattr(line, "v{}".format(i), v) |
| if session: |
| session.add(line) |
| else: |
| with self._session_scope() as session: |
| session.add(line) |
| |
| def save_policy(self, model): |
| """saves all policy rules to the storage.""" |
| with self._session_scope() as session: |
| query = session.query(self._db_class) |
| query.delete() |
| for sec in ["p", "g"]: |
| if sec not in model.model.keys(): |
| continue |
| for ptype, ast in model.model[sec].items(): |
| for rule in ast.policy: |
| self._save_policy_line(ptype, rule, session=session) |
| return True |
| |
| def add_policy(self, sec, ptype, rule): |
| """adds a policy rule to the storage.""" |
| self._save_policy_line(ptype, rule) |
| |
| def add_policies(self, sec, ptype, rules): |
| """adds a policy rules to the storage.""" |
| for rule in rules: |
| self._save_policy_line(ptype, rule) |
| |
| def remove_policy(self, sec, ptype, rule): |
| """removes a policy rule from the storage.""" |
| with self._session_scope() as session: |
| query = session.query(self._db_class) |
| query = query.filter(self._db_class.ptype == ptype) |
| for i, v in enumerate(rule): |
| query = query.filter(getattr(self._db_class, "v{}".format(i)) == v) |
| r = query.delete() |
| |
| return True if r > 0 else False |
| |
| def remove_policies(self, sec, ptype, rules): |
| """remove policy rules from the storage.""" |
| if not rules: |
| return |
| with self._session_scope() as session: |
| query = session.query(self._db_class) |
| query = query.filter(self._db_class.ptype == ptype) |
| rules = zip(*rules) |
| for i, rule in enumerate(rules): |
| query = query.filter( |
| or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule) |
| ) |
| query.delete() |
| |
| def remove_filtered_policy(self, sec, ptype, field_index, *field_values): |
| """removes policy rules that match the filter from the storage. |
| This is part of the Auto-Save feature. |
| """ |
| with self._session_scope() as session: |
| query = session.query(self._db_class).filter(self._db_class.ptype == ptype) |
| |
| if not (0 <= field_index <= 5): |
| return False |
| if not (1 <= field_index + len(field_values) <= 6): |
| return False |
| for i, v in enumerate(field_values): |
| if v != "": |
| v_value = getattr(self._db_class, "v{}".format(field_index + i)) |
| query = query.filter(v_value == v) |
| r = query.delete() |
| |
| return True if r > 0 else False |
| |
| def update_policy( |
| self, sec: str, ptype: str, old_rule: [str], new_rule: [str] |
| ) -> None: |
| """ |
| Update the old_rule with the new_rule in the database (storage). |
| |
| :param sec: section type |
| :param ptype: policy type |
| :param old_rule: the old rule that needs to be modified |
| :param new_rule: the new rule to replace the old rule |
| |
| :return: None |
| """ |
| |
| with self._session_scope() as session: |
| query = session.query(self._db_class).filter(self._db_class.ptype == ptype) |
| |
| # locate the old rule |
| for index, value in enumerate(old_rule): |
| v_value = getattr(self._db_class, "v{}".format(index)) |
| query = query.filter(v_value == value) |
| |
| # need the length of the longest_rule to perform overwrite |
| longest_rule = old_rule if len(old_rule) > len(new_rule) else new_rule |
| old_rule_line = query.one() |
| |
| # overwrite the old rule with the new rule |
| for index in range(len(longest_rule)): |
| if index < len(new_rule): |
| exec(f"old_rule_line.v{index} = new_rule[{index}]") |
| else: |
| exec(f"old_rule_line.v{index} = None") |
| |
| def update_policies( |
| self, |
| sec: str, |
| ptype: str, |
| old_rules: [ |
| [str], |
| ], |
| new_rules: [ |
| [str], |
| ], |
| ) -> None: |
| """ |
| Update the old_rules with the new_rules in the database (storage). |
| |
| :param sec: section type |
| :param ptype: policy type |
| :param old_rules: the old rules that need to be modified |
| :param new_rules: the new rules to replace the old rules |
| |
| :return: None |
| """ |
| for i in range(len(old_rules)): |
| self.update_policy(sec, ptype, old_rules[i], new_rules[i]) |
| |
| def update_filtered_policies( |
| self, sec, ptype, new_rules: [[str]], field_index, *field_values |
| ) -> [[str]]: |
| """update_filtered_policies updates all the policies on the basis of the filter.""" |
| |
| filter = Filter() |
| filter.ptype = ptype |
| |
| # Creating Filter from the field_index & field_values provided |
| for i in range(len(field_values)): |
| if field_index <= i and i < field_index + len(field_values): |
| setattr(filter, f"v{i}", field_values[i - field_index]) |
| else: |
| break |
| |
| self._update_filtered_policies(new_rules, filter) |
| |
| def _update_filtered_policies(self, new_rules, filter) -> [[str]]: |
| """_update_filtered_policies updates all the policies on the basis of the filter.""" |
| |
| with self._session_scope() as session: |
| # Load old policies |
| |
| query = session.query(self._db_class).filter( |
| self._db_class.ptype == filter.ptype |
| ) |
| filtered_query = self.filter_query(query, filter) |
| old_rules = filtered_query.all() |
| |
| # Delete old policies |
| |
| self.remove_policies("p", filter.ptype, old_rules) |
| |
| # Insert new policies |
| |
| self.add_policies("p", filter.ptype, new_rules) |
| |
| # return deleted rules |
| |
| return old_rules |