blob: 0516db6fd305a87ac8db7ca79a17b338ae227aeb [file] [log] [blame]
from typing import List, Dict
from casbin import persist, Model
from databases import Database
from sqlalchemy import Table
class Filter:
    """Per-column value lists used to restrict ``load_filtered_policy``.

    Each attribute is named after a policy column (``ptype``, ``v0`` .. ``v5``).
    A non-empty list restricts that column to the listed values; an empty list
    leaves the column unrestricted.

    The lists are created per instance. Mutating one filter in place
    (``f.ptype.append("p")`` or ``f.v0 += [...]``) therefore cannot leak into
    other ``Filter`` objects through a shared class-level list.
    """

    ptype: List[str]
    v0: List[str]
    v1: List[str]
    v2: List[str]
    v3: List[str]
    v4: List[str]
    v5: List[str]

    def __init__(self) -> None:
        # Fresh lists per instance; class-level `= []` defaults would be shared.
        self.ptype = []
        self.v0 = []
        self.v1 = []
        self.v2 = []
        self.v3 = []
        self.v4 = []
        self.v5 = []
class DatabasesAdapter(persist.Adapter):
    """Casbin adapter that stores policy rules through the ``databases`` library.

    Each rule is one row of ``table``: a ``ptype`` column plus up to six value
    columns ``v0`` .. ``v5``. Rows are read ordered by an ``id`` column, so the
    table is expected to have one.
    """

    # Columns that make up a policy line, in the order casbin expects them.
    cols = ["ptype"] + [f"v{i}" for i in range(6)]

    def __init__(self, db: Database, table: Table, filtered=False):
        """Create the adapter.

        :param db: ``databases.Database`` used for every query.
        :param table: SQLAlchemy table holding the policy rows.
        :param filtered: initial value reported by ``is_filtered``.
        """
        self.db: Database = db
        self.table: Table = table
        self.filtered: bool = filtered

    async def load_policy(self, model: Model):
        """Load every stored rule into ``model``.

        Rows are ordered by ``id``, as in ``load_filtered_policy``.
        """
        query = self.table.select().order_by(self.table.columns.id)
        await self._load_rows(query, model)

    async def save_policy(self, model: Model):
        """Replace all stored rules with the ``p`` and ``g`` rules of ``model``.

        The rows are built before touching the database. The delete and the
        inserts then run in a single transaction, so a failure part-way through
        does not leave the table emptied.

        :returns: ``True`` once the rules have been written.
        """
        values: List[Dict[str, str]] = []
        for sec in ["p", "g"]:
            if sec not in model.model.keys():
                continue
            for p_type, assertion in model.model[sec].items():
                for rule in assertion.policy:
                    values.append(self._policy_to_dict(p_type, rule))

        async with self.db.transaction():
            await self.db.execute(self.table.delete())
            if values:
                await self.db.execute_many(self.table.insert(), values)
        return True

    async def add_policy(self, sec, p_type, rule):
        """Insert a single rule of type ``p_type`` (``sec`` is unused)."""
        row = self._policy_to_dict(p_type, rule)
        await self.db.execute(self.table.insert(), row)

    async def remove_policy(self, sec, p_type, rule):
        """Delete rows whose ``ptype`` and ``v0``.. columns equal ``p_type`` and ``rule``.

        :returns: the truthiness of the value returned by ``Database.execute``.
        """
        query = self.table.delete().where(self.table.columns.ptype == p_type)
        for i, value in enumerate(rule):
            query = query.where(self.table.columns[f"v{i}"] == value)
        result = await self.db.execute(query)
        # NOTE(review): `Database.execute` is not guaranteed to return an
        # affected-row count on every backend (it may be None). A truthiness
        # check, as in remove_filtered_policy, avoids `None > 0` raising
        # TypeError. Confirm the per-backend return value if exact counts matter.
        return bool(result)

    async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
        """Delete rows of ``ptype`` matching ``field_values`` starting at column ``v{field_index}``.

        Empty (or None) entries in ``field_values`` add no condition for their column.

        :returns: ``False`` if ``field_index``/``field_values`` fall outside
            ``v0``..``v5``; otherwise the truthiness of ``Database.execute``'s result.
        """
        if not 0 <= field_index <= 5:
            return False
        if not 1 <= field_index + len(field_values) <= 6:
            return False
        query = self.table.delete().where(self.table.columns.ptype == ptype)
        for column_index, value in enumerate(field_values, start=field_index):
            if value:
                query = query.where(self.table.columns[f"v{column_index}"] == value)
        result = await self.db.execute(query)
        return bool(result)

    async def load_filtered_policy(self, model: Model, filter_: Filter) -> None:
        """Load only the rows matching ``filter_`` into ``model``.

        For every column in ``cols``, a non-empty list attribute of the same
        name on ``filter_`` restricts that column to the listed values. Missing
        or empty attributes leave the column unrestricted. Attributes that are
        not policy columns are ignored.
        """
        query = self.table.select().order_by(self.table.columns.id)
        for col in self.cols:
            allowed = getattr(filter_, col, None)
            if allowed:
                query = query.where(self.table.columns[col].in_(allowed))
        await self._load_rows(query, model)
        # save_policy deletes the whole table and re-inserts only the model's
        # rules. After a filtered load that would drop every row outside the
        # filter, so report "filtered" to let callers that honour is_filtered()
        # refuse to save.
        self.filtered = True

    def is_filtered(self):
        """Return the constructor flag, or ``True`` once a filtered load has run."""
        return self.filtered

    async def _load_rows(self, query, model: Model) -> None:
        """Run ``query`` and hand each resulting row to casbin as a policy line."""
        rows = await self.db.fetch_all(query)
        for row in rows:
            persist.load_policy_line(self._row_to_line(row), model)

    def _row_to_line(self, row) -> str:
        """Render ``row`` as ``"ptype, v0, v1, ..."``.

        Only columns listed in ``cols`` are used (so ``id`` is dropped). Values
        are emitted in ``cols`` order whatever the table's column order, and
        NULL values are skipped.
        """
        mapping = dict(row._mapping)
        values = (mapping.get(col) for col in self.cols)
        return ", ".join(value for value in values if value is not None)

    @staticmethod
    def _policy_to_dict(p_type: str, rule: List[str]) -> Dict[str, str]:
        """Map a rule to column values: ``ptype`` plus ``v{i}`` for each element of ``rule``."""
        row: Dict[str, str] = {"ptype": p_type}
        row.update((f"v{i}", value) for i, value in enumerate(rule))
        return row