|  | #  Licensed to the Apache Software Foundation (ASF) under one or more | 
|  | #  contributor license agreements.  See the NOTICE file distributed with | 
|  | #  this work for additional information regarding copyright ownership. | 
|  | #  The ASF licenses this file to You under the Apache License, Version 2.0 | 
|  | #  (the "License"); you may not use this file except in compliance with | 
|  | #  the License.  You may obtain a copy of the License at | 
|  | # | 
|  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
|  | # | 
|  | #  Unless required by applicable law or agreed to in writing, software | 
|  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
|  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | #  See the License for the specific language governing permissions and | 
|  | #  limitations under the License. | 
|  | # | 
|  |  | 
|  | from __future__ import annotations | 
|  | import abc | 
|  | from typing import Any, Literal | 
|  | from pathlib import Path | 
|  |  | 
|  | import pydantic | 
|  |  | 
|  | from airavata_auth.device_auth import AuthContext | 
|  |  | 
# Placeholder alias used in annotations throughout this module. The concrete
# class appears to live in ``.task`` (its import is left disabled here);
# presumably this avoids a circular import -- TODO confirm.
Task = Any

# Every state name a runtime may report. Experiment-level and job-level states
# share one alias; 'CANCELED' and 'FAILED' belong to both groups and are
# therefore listed only once.
States = Literal[
    # Experiment states
    'CREATED',
    'VALIDATED',
    'SCHEDULED',
    'LAUNCHED',
    'EXECUTING',
    'CANCELING',
    'CANCELED',
    'COMPLETED',
    'FAILED',
    # Job states (additional to the shared 'CANCELED' / 'FAILED')
    'SUBMITTED',
    'QUEUED',
    'ACTIVE',
    'COMPLETE',
    'SUSPENDED',
    'UNKNOWN',
    'NON_CRITICAL_FAIL',
]
|  |  | 
class Runtime(abc.ABC, pydantic.BaseModel):
    """Abstract pydantic model describing where and how a task is run.

    Concrete subclasses implement the task operations declared below. The
    static factories at the bottom build the concrete runtimes defined in
    this module.
    """

    # Identifier of the runtime kind; ``create`` recognises "mock" and "remote".
    id: str
    # Named runtime parameters; empty when none are supplied.
    args: dict[str, str | int | float] = pydantic.Field(default={})

    @abc.abstractmethod
    def execute(self, task: Task) -> None:
        """Run ``task`` on this runtime."""

    @abc.abstractmethod
    def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
        """Run the Python source ``code`` with ``libraries`` in the context of ``task``."""

    @abc.abstractmethod
    def execute_cmd(self, cmd: str, task: Task) -> bytes:
        """Run the command ``cmd`` in the context of ``task`` and return bytes."""

    @abc.abstractmethod
    def status(self, task: Task) -> tuple[str, str]:
        """Return a pair of strings describing ``task``.

        The runtimes in this module return ``(job id, state name)``.
        """

    @abc.abstractmethod
    def signal(self, signal: str, task: Task) -> None:
        """Deliver ``signal`` to ``task``."""

    @abc.abstractmethod
    def ls(self, task: Task) -> list[str]:
        """Return a list of file names for ``task``."""

    @abc.abstractmethod
    def upload(self, file: Path, task: Task) -> str:
        """Upload the local ``file`` for ``task`` and return a string result."""

    @abc.abstractmethod
    def download(self, file: str, local_dir: str, task: Task) -> str:
        """Download ``file`` of ``task`` into ``local_dir`` and return a string result."""

    @abc.abstractmethod
    def cat(self, file: str, task: Task) -> bytes:
        """Return the contents of ``file`` of ``task`` as bytes."""

    def __str__(self) -> str:
        """Render as ``<ClassName>(args=<args>)``."""
        class_name = type(self).__name__
        return f"{class_name}(args={self.args})"

    @staticmethod
    def create(id: str, args: dict[str, Any]) -> Runtime:
        """Build the runtime registered under ``id``, passing ``args`` as keywords.

        Raises:
            ValueError: if ``id`` is neither "mock" nor "remote".
        """
        if id == "mock":
            return Mock(**args)
        if id == "remote":
            return Remote(**args)
        raise ValueError(f"Unknown runtime id: {id}")

    @staticmethod
    def Remote(**kwargs):
        """Build a ``Remote`` runtime from keyword arguments."""
        # Inside a function body the name resolves to the module-level class,
        # not to this static method.
        remote_runtime = Remote(**kwargs)
        return remote_runtime

    @staticmethod
    def Local(**kwargs):
        """Build the local runtime, which is currently the ``Mock`` runtime."""
        local_runtime = Mock(**kwargs)
        return local_runtime
|  |  | 
|  |  | 
class Mock(Runtime):
    """Stand-in runtime that contacts no external service.

    Operations return empty placeholder values; ``status`` advances an
    internal counter by a random step and reports completion once the
    counter exceeds 10.
    """

    # Progress counter advanced by every ``status`` call.
    _state: int = 0

    def __init__(self) -> None:
        """Initialise with the fixed id "mock" and no arguments."""
        super().__init__(id="mock", args={})

    def execute(self, task: Task) -> None:
        """Assign fresh random UUID strings to ``task.agent_ref`` and ``task.ref``."""
        from uuid import uuid4

        task.agent_ref = str(uuid4())
        task.ref = str(uuid4())

    def execute_cmd(self, cmd: str, task: Task) -> bytes:
        """Ignore ``cmd`` and return empty bytes."""
        return b""

    def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
        """Do nothing."""

    def status(self, task: Task) -> tuple[str, str]:
        """Advance the counter by 0-5 and return ``("N/A", state)``.

        ``state`` is "COMPLETED" once the counter exceeds 10, otherwise
        "RUNNING".
        """
        from random import randint

        self._state += randint(0, 5)
        state_name = "COMPLETED" if self._state > 10 else "RUNNING"
        return "N/A", state_name

    def signal(self, signal: str, task: Task) -> None:
        """Do nothing."""

    def ls(self, task: Task) -> list[str]:
        """Return a list holding a single empty string."""
        return [""]

    def upload(self, file: Path, task: Task) -> str:
        """Return an empty string without uploading anything."""
        return ""

    def download(self, file: str, local_dir: str, task: Task) -> str:
        """Return an empty string without downloading anything."""
        return ""

    def cat(self, file: str, task: Task) -> bytes:
        """Return empty bytes."""
        return b""
|  |  | 
|  |  | 
# Task attributes that must be set (by ``Remote.execute``) before a command or
# file operation can be issued for the task.
_LAUNCHED_TASK_FIELDS = ("ref", "pid", "agent_ref", "sr_host", "workdir")


def _airavata_operator():
    """Return an ``AiravataOperator`` built from the current access token.

    The operator class is imported lazily, at call time, rather than when this
    module is imported.
    """
    from .airavata import AiravataOperator

    return AiravataOperator(AuthContext.get_access_token())


def _assert_task_fields(task: Task, *names: str) -> None:
    """Assert that every attribute of ``task`` listed in ``names`` is not ``None``."""
    for name in names:
        assert getattr(task, name) is not None, f"task.{name} is not set"


class Remote(Runtime):
    """Runtime that runs tasks through an ``AiravataOperator``.

    The scheduling parameters given to the constructor are stored in ``args``.
    ``execute`` launches an experiment and records its identifiers on the
    task; every other operation requires those identifiers to be present.
    """

    def __init__(
        self,
        cluster: str,
        category: str,
        queue_name: str,
        node_count: int,
        cpu_count: int,
        walltime: int,
        gpu_count: int = 0,
        group: str = "Default",
    ) -> None:
        """Store the scheduling parameters in ``args`` under the id "remote"."""
        super().__init__(
            id="remote",
            args=dict(
                cluster=cluster,
                category=category,
                queue_name=queue_name,
                node_count=node_count,
                cpu_count=cpu_count,
                gpu_count=gpu_count,
                walltime=walltime,
                group=group,
            ),
        )

    def execute(self, task: Task) -> None:
        """Launch ``task`` as an experiment and record its identifiers on it.

        On success sets ``task.agent_ref``, ``task.pid``, ``task.ref``,
        ``task.workdir`` and ``task.sr_host`` from the launch result.

        Raises:
            AssertionError: if the task was already launched (``ref`` or
                ``agent_ref`` set) or a required key is missing from ``args``.
            Exception: any launch failure is logged and re-raised unchanged.
        """
        assert task.ref is None
        assert task.agent_ref is None
        required_keys = {"cluster", "group", "queue_name", "node_count", "cpu_count", "gpu_count", "walltime"}
        assert required_keys.issubset(self.args.keys())
        print(f"[Remote] Creating Experiment: name={task.name}")

        av = _airavata_operator()
        try:
            # NOTE(review): "gpu_count" is required above but is not forwarded
            # to launch_experiment -- confirm whether that is intended.
            launch_state = av.launch_experiment(
                experiment_name=task.name,
                app_name=task.app_id,
                project=task.project,
                inputs=task.inputs,
                computation_resource_name=str(self.args["cluster"]),
                queue_name=str(self.args["queue_name"]),
                node_count=int(self.args["node_count"]),
                cpu_count=int(self.args["cpu_count"]),
                walltime=int(self.args["walltime"]),
                group=str(self.args["group"]),
            )
            task.agent_ref = launch_state.agent_ref
            task.pid = launch_state.process_id
            task.ref = launch_state.experiment_id
            task.workdir = launch_state.experiment_dir
            task.sr_host = launch_state.sr_host
            print(f"[Remote] Experiment Launched: id={task.ref}")
        except Exception as e:
            print(f"[Remote] Failed to launch experiment: {repr(e)}")
            # Bare raise re-raises the active exception with its traceback intact.
            raise

    def execute_cmd(self, cmd: str, task: Task) -> bytes:
        """Run ``cmd`` through the task's agent and return the result.

        Best effort: any failure of the command call is logged and empty
        bytes are returned instead of raising.
        """
        _assert_task_fields(task, *_LAUNCHED_TASK_FIELDS)

        av = _airavata_operator()
        try:
            return av.execute_cmd(task.agent_ref, cmd)
        except Exception as e:
            print(f"[Remote] Failed to execute command: {repr(e)}")
            return b""

    def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
        """Run Python ``code`` with ``libraries`` for ``task`` and print the result."""
        _assert_task_fields(task, "ref", "agent_ref", "pid")

        av = _airavata_operator()
        # The runtime arguments are read from the task's own runtime, as before.
        result = av.execute_py(task.project, libraries, code, task.agent_ref, task.pid, task.runtime.args)
        print(result)

    def status(self, task: Task) -> tuple[str, States]:
        """Return ``(job id, state name)`` for ``task``.

        The job state is preferred; when it is ``UNKNOWN`` or
        ``NON_CRITICAL_FAIL`` the experiment state is reported instead.
        """
        _assert_task_fields(task, "ref", "agent_ref")

        # Imported here as well because the JobState members are needed by name.
        from .airavata import AiravataOperator

        av = _airavata_operator()
        job_id, job_state = av.get_task_status(task.ref)
        uninformative = [AiravataOperator.JobState.UNKNOWN, AiravataOperator.JobState.NON_CRITICAL_FAIL]
        if job_state in uninformative:
            return job_id, av.get_experiment_status(task.ref).name
        return job_id, job_state.name

    def signal(self, signal: str, task: Task) -> None:
        """Stop the experiment behind ``task``.

        NOTE(review): the ``signal`` argument is not inspected; the experiment
        is stopped whatever its value -- confirm this is intended.
        """
        _assert_task_fields(task, "ref", "agent_ref")

        _airavata_operator().stop_experiment(task.ref)

    def ls(self, task: Task) -> list[str]:
        """Return the file listing of the task's working directory."""
        _assert_task_fields(task, *_LAUNCHED_TASK_FIELDS)

        av = _airavata_operator()
        return av.list_files(task.pid, task.agent_ref, task.sr_host, task.workdir)

    def upload(self, file: Path, task: Task) -> str:
        """Upload ``file`` into the task's working directory.

        Returns the last element of the list returned by the upload call
        (a single file is uploaded).
        """
        _assert_task_fields(task, *_LAUNCHED_TASK_FIELDS)

        av = _airavata_operator()
        uploaded = av.upload_files(task.pid, task.agent_ref, task.sr_host, [file], task.workdir)
        return uploaded.pop()

    def download(self, file: str, local_dir: str, task: Task) -> str:
        """Download ``file`` from the task's working directory into ``local_dir``.

        Returns whatever the operator's ``download_file`` call returns.
        """
        _assert_task_fields(task, *_LAUNCHED_TASK_FIELDS)

        av = _airavata_operator()
        return av.download_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir, local_dir)

    def cat(self, file: str, task: Task) -> bytes:
        """Return the contents of ``file`` from the task's working directory."""
        _assert_task_fields(task, *_LAUNCHED_TASK_FIELDS)

        av = _airavata_operator()
        return av.cat_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir)
|  |  | 
|  |  | 
def find_runtimes(
    cluster: str | None = None,
    category: str | None = None,
    node_count: int | None = None,
    cpu_count: int | None = None,
    group: str | None = None,
    queue_name: str | None = None,
) -> list[Runtime]:
    """Return the available runtimes that satisfy the given constraints.

    Args:
        cluster: exact cluster name to match; any cluster when omitted.
        category: exact category to match; any category when omitted.
        node_count: minimum node count the runtime must offer (default 1).
        cpu_count: minimum CPU count the runtime must offer (default 1).
        group: restrict the search to this group name; all groups when omitted.
        queue_name: exact queue name to match; any queue when omitted.

    Returns:
        Matching runtimes, in the order the groups and their runtimes are
        reported by the operator.

    Raises:
        AssertionError: if ``group`` is given but unknown, or a group has no name.
    """
    from .airavata import AiravataOperator

    operator = AiravataOperator(AuthContext.get_access_token())
    available_groups = operator.get_available_groups()
    group_names = [str(profile.groupResourceProfileName) for profile in available_groups]

    if group is None:
        selected_groups = available_groups
    else:
        assert group in group_names, f"Group {group} was not found. Available groups: {repr(group_names)}"
        selected_groups = [
            profile
            for profile in available_groups
            if str(profile.groupResourceProfileName) == group
        ]

    def satisfies(candidate: Runtime) -> bool:
        # Checks run in a fixed order and stop at the first failure. A falsy
        # filter value falls back to the candidate's own value (or to 1 for
        # the counts), which makes that check pass trivially.
        spec = candidate.args
        return (
            (node_count or 1) <= int(spec["node_count"])
            and (cpu_count or 1) <= int(spec["cpu_count"])
            and (cluster or spec["cluster"]) == spec["cluster"]
            and (category or spec["category"]) == spec["category"]
            and (queue_name or spec["queue_name"]) == spec["queue_name"]
        )

    found: list[Runtime] = []
    for profile in selected_groups:
        assert profile.groupResourceProfileName is not None, f"Group {profile} has no name"
        found.extend(
            candidate
            for candidate in operator.get_available_runtimes(group=profile.groupResourceProfileName)
            if satisfies(candidate)
        )
    return found
|  |  | 
def is_terminal_state(x: States) -> bool:
    """Return whether ``x`` names a terminal state.

    The terminal states are "CANCELED", "COMPLETE", "COMPLETED" and "FAILED".
    """
    terminal_states = ("CANCELED", "COMPLETE", "COMPLETED", "FAILED")
    return x in terminal_states