| use rocket::{ |
| fairing::{Fairing, Info, Kind}, |
| http::Status, |
| request::{self, FromRequest, Request}, |
| Data, |
| }; |
| |
| use casbin::prelude::*; |
| use casbin::{CachedEnforcer, CoreApi, Result as CasbinResult}; |
| use parking_lot::RwLock; |
| use std::sync::Arc; |
| |
| pub struct CasbinVals { |
| pub subject: Option<String>, |
| pub domain: Option<String>, |
| } |
| |
| impl CasbinVals { |
| pub fn new(subject: Option<String>, domain: Option<String>) -> CasbinVals { |
| CasbinVals { subject, domain } |
| } |
| } |
| |
| #[derive(Clone)] |
| pub struct CasbinGuard(Option<Status>); |
| |
| impl<'a, 'r> FromRequest<'a, 'r> for CasbinGuard { |
| type Error = (); |
| |
| fn from_request(request: &'a Request<'r>) -> request::Outcome<CasbinGuard, ()> { |
| match *request.local_cache(|| CasbinGuard(Status::from_code(0))) { |
| CasbinGuard(Some(Status::Ok)) => { |
| request::Outcome::Success(CasbinGuard(Some(Status::Ok))) |
| } |
| CasbinGuard(Some(err_status)) => request::Outcome::Failure((err_status, ())), |
| _ => request::Outcome::Failure((Status::BadGateway, ())), |
| } |
| } |
| } |
| |
| #[derive(Clone)] |
| pub struct CasbinFairing { |
| pub enforcer: Arc<RwLock<CachedEnforcer>>, |
| } |
| |
| impl CasbinFairing { |
| pub async fn new<M: TryIntoModel, A: TryIntoAdapter>(m: M, a: A) -> CasbinResult<Self> { |
| let enforcer: CachedEnforcer = CachedEnforcer::new(m, a).await?; |
| Ok(CasbinFairing { |
| enforcer: Arc::new(RwLock::new(enforcer)), |
| }) |
| } |
| |
| pub fn get_enforcer(&mut self) -> Arc<RwLock<CachedEnforcer>> { |
| self.enforcer.clone() |
| } |
| |
| pub fn set_enforcer(e: Arc<RwLock<CachedEnforcer>>) -> CasbinFairing { |
| CasbinFairing { enforcer: e } |
| } |
| } |
| |
| impl Fairing for CasbinFairing { |
| fn info(&self) -> Info { |
| Info { |
| name: "Casbin Fairing", |
| kind: Kind::Request | Kind::Response, |
| } |
| } |
| |
| fn on_request(&self, request: &mut Request, _data: &Data) { |
| let cloned_enforce = self.enforcer.clone(); |
| let path = request.uri().path().to_owned(); |
| let action = request.method().as_str().to_owned(); |
| |
| let (subject, domain) = match request.local_cache(|| CasbinVals { |
| subject: None, |
| domain: None, |
| }) { |
| CasbinVals { |
| subject: Some(x), |
| domain: Some(y), |
| } => (Some(x.to_owned()), Some(y.to_owned())), |
| CasbinVals { |
| subject: Some(x), |
| domain: None, |
| } => (Some(x.to_owned()), None), |
| _ => (None, None), |
| }; |
| |
| if let Some(subject) = subject { |
| if let Some(domain) = domain { |
| let mut lock = cloned_enforce.write(); |
| match lock.enforce_mut(vec![subject, domain, path, action]) { |
| Ok(true) => { |
| drop(lock); |
| request.local_cache(|| CasbinGuard(Some(Status::Ok))); |
| } |
| Ok(false) => { |
| drop(lock); |
| request.local_cache(|| CasbinGuard(Some(Status::Forbidden))); |
| } |
| Err(_) => { |
| drop(lock); |
| request.local_cache(|| CasbinGuard(Some(Status::BadGateway))); |
| } |
| }; |
| } else { |
| let mut lock = cloned_enforce.write(); |
| match lock.enforce_mut(vec![subject, path, action]) { |
| Ok(true) => { |
| drop(lock); |
| request.local_cache(|| CasbinGuard(Some(Status::Ok))); |
| } |
| Ok(false) => { |
| drop(lock); |
| request.local_cache(|| CasbinGuard(Some(Status::Forbidden))); |
| } |
| Err(_) => { |
| drop(lock); |
| request.local_cache(|| CasbinGuard(Some(Status::BadGateway))); |
| } |
| }; |
| } |
| } else { |
| request.local_cache(|| CasbinGuard(Some(Status::BadGateway))); |
| } |
| } |
| } |