blob: 944055ad26868d699a5fbf4b12c0831c1cc4030f [file]
import os
import time
import requests
from rich.console import Console
from airavata_sdk import Settings
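# Auth settings (client ID, realm, server URL) are read through airavata_sdk.Settings, which presumably loads them from the environment / a .env file.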
class AuthContext:
    """Interactive login via the OAuth 2.0 Device Authorization Grant (RFC 8628).

    The auth server URL, realm and client ID come from ``airavata_sdk.Settings``;
    endpoints follow the ``/realms/<realm>/protocol/openid-connect/...`` layout.
    On success the access token is stored in the ``CS_ACCESS_TOKEN`` environment
    variable of the current process, so later calls reuse it instead of logging
    in again.
    """

    # Seconds to wait for a single HTTP request to the auth server; without a
    # timeout a stalled server would hang the login forever.
    REQUEST_TIMEOUT = 30

    # Seconds added to the polling interval on a "slow_down" reply, as required
    # by RFC 8628 section 3.5.
    SLOW_DOWN_INCREMENT = 5

    @staticmethod
    def get_access_token():
        """Return the cached access token, running the login flow if none is cached.

        Raises:
            KeyError: if no token was cached and the login flow did not produce
                one (``login``/``poll_for_token`` print the reason).
        """
        if os.environ.get("CS_ACCESS_TOKEN") is None:
            AuthContext().login()
        return os.environ["CS_ACCESS_TOKEN"]

    def __init__(self):
        """Load the auth settings and check that the required ones are present.

        Raises:
            ValueError: if the client ID, realm, or auth server URL is empty/unset.
        """
        self.settings = Settings()
        if (
            not self.settings.AUTH_CLIENT_ID
            or not self.settings.AUTH_REALM
            or not self.settings.AUTH_SERVER_URL
        ):
            raise ValueError("Missing required environment variables for client ID, realm, or auth server URL")
        self.device_code = None  # filled in by login() from the device authorization response
        self.interval = None  # polling interval in seconds, filled in by login()
        self.console = Console()

    def _endpoint(self, suffix):
        """Return the OpenID Connect endpoint URL ``<server>/realms/<realm>/protocol/openid-connect/<suffix>``."""
        return (
            f"{self.settings.AUTH_SERVER_URL}/realms/{self.settings.AUTH_REALM}"
            f"/protocol/openid-connect/{suffix}"
        )

    @staticmethod
    def _error_code(response):
        """Return the OAuth ``error`` field of *response*'s JSON body, or None.

        None is returned when the body is not valid JSON or not a JSON object,
        so an unexpected (e.g. HTML) error page is reported rather than raised.
        """
        try:
            body = response.json()
        except ValueError:  # requests' JSONDecodeError subclasses ValueError
            return None
        return body.get("error") if isinstance(body, dict) else None

    def login(self):
        """Run the device flow unless a token is already cached.

        Server errors are printed rather than raised; in that case
        ``CS_ACCESS_TOKEN`` stays unset.
        """
        if os.environ.get("CS_ACCESS_TOKEN") is not None:
            return
        # Step 1: Request device and user code
        response = requests.post(
            self._endpoint("auth/device"),
            data={"client_id": self.settings.AUTH_CLIENT_ID, "scope": "openid"},
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            print(
                f"Error in authentication request: {response.status_code} - {response.text}", flush=True)
            return
        data = response.json()
        self.device_code = data.get("device_code")
        self.interval = data.get("interval", 5)
        url = data.get("verification_uri_complete")
        if url is None:
            # verification_uri_complete is OPTIONAL (RFC 8628 section 3.2); the
            # plain verification_uri requires the user to type the code.
            url = data.get("verification_uri")
            print(f"Enter code {data.get('user_code')} at the link below.", flush=True)
        # Step 2: Poll for the token
        self.poll_for_token(url)

    def poll_for_token(self, url):
        """Poll the token endpoint until the user approves or a terminal error occurs.

        Args:
            url: verification link displayed to the user while polling.

        On success the access token is stored in ``CS_ACCESS_TOKEN``; on any
        other error the server response is printed and nothing is stored.
        The console is cleared when polling ends.

        Raises:
            RuntimeError: if called before ``login`` obtained a device code.
        """
        if self.interval is None:
            raise RuntimeError("login() must obtain a device code before polling for a token")
        token_url = self._endpoint("token")
        counter = 0
        with self.console.status(f"Authenticate via link: [link={url}]{url}[/link]", refresh_per_second=1) as status:
            while True:
                response = requests.post(token_url, data={
                    "client_id": self.settings.AUTH_CLIENT_ID,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                    "device_code": self.device_code,
                }, timeout=self.REQUEST_TIMEOUT)
                if response.status_code == 200:
                    access_token = response.json().get("access_token")
                    if not access_token:
                        # os.environ rejects None; report instead of crashing.
                        print(f"Error during authentication: no access_token in response - {response.text}")
                        break
                    print("Authenticated.")
                    os.environ['CS_ACCESS_TOKEN'] = access_token
                    break
                error = self._error_code(response) if response.status_code == 400 else None
                if error in ("authorization_pending", "slow_down"):
                    if error == "slow_down":
                        # RFC 8628 section 3.5: keep polling, but back off for
                        # this and all subsequent requests.
                        self.interval += self.SLOW_DOWN_INCREMENT
                    counter += 1
                    status.update(
                        f"Authenticate via link: [link={url}]{url}[/link] ({counter})")
                else:
                    print(
                        f"Error during authentication: {response.status_code} - {response.text}")
                    break
                time.sleep(self.interval)
        self.console.clear()