| import base64 |
| import binascii |
| |
| import casbin |
| import uvicorn |
| from fastapi import FastAPI |
| from starlette.authentication import AuthenticationBackend, AuthenticationError, SimpleUser, AuthCredentials |
| from starlette.middleware.authentication import AuthenticationMiddleware |
| |
| from fastapi_casbin_auth import CasbinMiddleware |
| |
| app = FastAPI() |
| |
| |
| class BasicAuth(AuthenticationBackend): |
| async def authenticate(self, request): |
| if "Authorization" not in request.headers: |
| return None |
| |
| auth = request.headers["Authorization"] |
| try: |
| scheme, credentials = auth.split() |
| decoded = base64.b64decode(credentials).decode("ascii") |
| except (ValueError, UnicodeDecodeError, binascii.Error): |
| raise AuthenticationError("Invalid basic auth credentials") |
| |
| username, _, password = decoded.partition(":") |
| return AuthCredentials(["authenticated"]), SimpleUser(username) |
| |
| |
| enforcer = casbin.Enforcer('../examples/rbac_model.conf', '../examples/rbac_policy.csv') |
| |
| app.add_middleware(CasbinMiddleware, enforcer=enforcer) |
| app.add_middleware(AuthenticationMiddleware, backend=BasicAuth()) |
| |
| |
| @app.get('/') |
| async def index(): |
| return "If you see this, you have been authenticated." |
| |
| |
| @app.get('/dataset1/protected') |
| async def auth_test(): |
| return "You must be alice to see this." |
| |
| |
| if __name__ == '__main__': |
| uvicorn.run('test:app', debug=True) |