| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Module for managing ML models in Apache Beam pipelines. |
| |
| This module provides classes and functions to efficiently manage multiple |
| machine learning models within Apache Beam pipelines. It includes functionality |
| for loading, caching, and updating models using multi-process shared memory, |
| ensuring that models are reused across different workers to optimize resource |
| usage and performance. |
| """ |
| |
| import gc |
| import heapq |
| import itertools |
| import logging |
| import subprocess |
| import threading |
| import time |
| from collections import Counter |
| from collections import OrderedDict |
| from collections import defaultdict |
| from collections import deque |
| from typing import Any |
| from typing import Callable |
| from typing import Dict |
| from typing import Optional |
| from typing import Tuple |
| |
| import numpy as np |
| import torch |
| from scipy.optimize import nnls |
| |
| from apache_beam.utils import multi_process_shared |
| |
| logger = logging.getLogger(__name__) |
| |
| |
| class GPUMonitor: |
| """Monitors GPU memory usage in a separate thread using nvidia-smi. |
| |
| This class continuously polls GPU memory statistics to track current usage |
| and peak usage over a sliding time window. It serves as the source of truth |
| for the ModelManager's resource decisions. |
| |
| Attributes: |
| fallback_memory_mb: Default total memory if hardware detection fails. |
| poll_interval: Seconds between memory checks. |
| peak_window_seconds: Duration to track peak memory usage. |
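
  Example (a minimal usage sketch):

    monitor = GPUMonitor(poll_interval=1.0)
    monitor.start()
    current_mb, peak_mb, total_mb = monitor.get_stats()
    monitor.stop()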
| """ |
| def __init__( |
| self, |
| fallback_memory_mb: float = 16000.0, |
| poll_interval: float = 0.5, |
| peak_window_seconds: float = 30.0): |
| self._current_usage = 0.0 |
| self._peak_usage = 0.0 |
| self._total_memory = fallback_memory_mb |
| self._poll_interval = poll_interval |
| self._peak_window_seconds = peak_window_seconds |
| self._memory_history = deque() |
| self._running = False |
    self._thread = None
    self._gpu_available = False
| self._lock = threading.Lock() |
| |
| def _detect_hardware(self): |
| try: |
| cmd = [ |
| "nvidia-smi", |
| "--query-gpu=memory.total", |
| "--format=csv,noheader,nounits" |
| ] |
      output = subprocess.check_output(cmd, text=True).strip()
      # nvidia-smi prints one line per device on multi-GPU hosts;
      # use the first device's reading.
      self._total_memory = float(output.splitlines()[0])
| return True |
| except (FileNotFoundError, subprocess.CalledProcessError): |
| logger.warning( |
| "nvidia-smi not found or failed. Defaulting total memory to %s MB", |
| self._total_memory) |
| return False |
| except Exception as e: |
| logger.warning( |
| "Error parsing nvidia-smi output: %s. " |
| "Defaulting total memory to %s MB", |
| e, |
| self._total_memory) |
| return False |
| |
| def start(self): |
| self._gpu_available = self._detect_hardware() |
| if self._running or not self._gpu_available: |
| return |
| self._running = True |
| self._thread = threading.Thread(target=self._poll_loop, daemon=True) |
| self._thread.start() |
| |
| def stop(self): |
| self._running = False |
| if self._thread: |
| self._thread.join() |
| |
| def reset_peak(self): |
| with self._lock: |
| now = time.time() |
| self._memory_history.clear() |
| self._memory_history.append((now, self._current_usage)) |
| self._peak_usage = self._current_usage |
| |
| def get_stats(self) -> Tuple[float, float, float]: |
| with self._lock: |
| return self._current_usage, self._peak_usage, self._total_memory |
| |
  def refresh(self):
    """Forces an immediate poll of the GPU."""
    self._record_usage(self._get_nvidia_smi_used())

  def _record_usage(self, usage: float):
    """Records a usage sample and recomputes the sliding-window peak."""
    now = time.time()
    with self._lock:
      self._current_usage = usage
      self._memory_history.append((now, usage))
      # Drop samples that fell out of the peak window, then recompute
      while self._memory_history and (now - self._memory_history[0][0]
                                      > self._peak_window_seconds):
        self._memory_history.popleft()
      self._peak_usage = (
          max(m for _, m in self._memory_history)
          if self._memory_history else usage)

  def _get_nvidia_smi_used(self) -> float:
    try:
      cmd = [
          "nvidia-smi",
          "--query-gpu=memory.free",
          "--format=csv,noheader,nounits"
      ]
      output = subprocess.check_output(cmd, text=True).strip()
      # As above, use the first device's reading on multi-GPU hosts
      free_memory = float(output.splitlines()[0])
      return self._total_memory - free_memory
    except Exception as e:
      logger.warning('Failed to get GPU memory usage: %s', e)
      return 0.0

  def _poll_loop(self):
    while self._running:
      self._record_usage(self._get_nvidia_smi_used())
      time.sleep(self._poll_interval)
| |
| |
| class ResourceEstimator: |
| """Estimates individual model memory usage using statistical observation. |
| |
| Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of |
| individual models based on aggregate system memory readings and the |
| configuration of active models at that time. |
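
  For example, if peak readings hover around 6000 MB while only {'m1': 1}
  is active and around 10000 MB while {'m1': 1, 'm2': 1} is active, the
  solver attributes roughly 6000 MB to m1 and 4000 MB to m2 (less any
  constant system overhead captured by the bias term).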
| """ |
| def __init__( |
| self, |
| smoothing_factor: float = 0.2, |
| min_data_points: int = 5, |
| verbose_logging: bool = False): |
| self.smoothing_factor = smoothing_factor |
| self.min_data_points = min_data_points |
| self.verbose_logging = verbose_logging |
| self.estimates: Dict[str, float] = {} |
| self.history = defaultdict(lambda: deque(maxlen=20)) |
| self.known_models = set() |
| self._lock = threading.Lock() |
| |
| def logging_info(self, message: str, *args): |
| if self.verbose_logging: |
| logger.info(message, *args) |
| |
| def is_unknown(self, model_tag: str) -> bool: |
| with self._lock: |
| return model_tag not in self.estimates |
| |
| def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float: |
| with self._lock: |
| return self.estimates.get(model_tag, default_mb) |
| |
| def set_initial_estimate(self, model_tag: str, cost: float): |
| with self._lock: |
| self.estimates[model_tag] = cost |
| self.known_models.add(model_tag) |
| self.logging_info("Initial Profile for %s: %s MB", model_tag, cost) |
| |
| def add_observation( |
| self, active_snapshot: Dict[str, int], peak_memory: float): |
| if active_snapshot: |
| model_list = "\n".join( |
| f"\t- {model}: {count}" |
| for model, count in sorted(active_snapshot.items())) |
| else: |
| model_list = "\t- None" |
| |
| self.logging_info( |
| "Adding Observation:\n PeakMemory: %.1f MB\n Instances:\n%s", |
| peak_memory, |
| model_list) |
| if not active_snapshot: |
| return |
| with self._lock: |
| config_key = tuple(sorted(active_snapshot.items())) |
| self.history[config_key].append(peak_memory) |
| for tag in active_snapshot: |
| self.known_models.add(tag) |
| self._solve() |
| |
  def _solve(self):
    """Solves Ax=b with NNLS using raw readings (no pre-averaging).

    This builds a 'tall' matrix A in which every memory reading is a
    separate equation.
    """
| unique = sorted(list(self.known_models)) |
| |
| # We need to build the matrix first to know if we have enough data points |
| A, b = [], [] |
| |
| for config_key, mem_values in self.history.items(): |
| if not mem_values: |
| continue |
| |
| # 1. Create the feature row for this configuration ONCE |
| # (It represents the model counts + bias) |
| counts = dict(config_key) |
| feature_row = [counts.get(model, 0) for model in unique] |
| feature_row.append(1) # Bias column |
| |
| # 2. Add a separate row to the matrix for EVERY individual reading |
| # Instead of averaging, we flatten the history into the matrix |
| for reading in mem_values: |
| A.append(feature_row) # The inputs (models) stay the same |
| b.append(reading) # The output (memory) varies due to noise |
| |
| # Convert to numpy for SciPy |
| A = np.array(A) |
| b = np.array(b) |
| |
    if (len(self.history) < len(unique) + 1 or
        len(A) < self.min_data_points):
      # Not enough distinct configurations or raw readings to solve yet
      return
| |
| self.logging_info( |
| "Solving with %s total observations for %s models.", |
| len(A), |
| len(unique)) |
| |
| try: |
| # Solve using Non-Negative Least Squares |
| # x will be >= 0 |
| x, _ = nnls(A, b) |
| |
| weights = x[:-1] |
| bias = x[-1] |
| |
| for i, model in enumerate(unique): |
| calculated_cost = weights[i] |
| |
| if model in self.estimates: |
| old = self.estimates[model] |
| new = (old * (1 - self.smoothing_factor)) + ( |
| calculated_cost * self.smoothing_factor) |
| self.estimates[model] = new |
| else: |
| self.estimates[model] = calculated_cost |
| |
| self.logging_info( |
| "Updated Estimate for %s: %.1f MB", model, self.estimates[model]) |
| self.logging_info("System Bias: %s MB", bias) |
| |
| except Exception as e: |
| logger.error("Solver failed: %s", e) |
| |
| |
| class QueueTicket: |
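  """A waiter's place in line, ordered by (priority, FIFO ticket number)."""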
| def __init__(self, priority, ticket_num, tag): |
| self.priority = priority |
| self.ticket_num = ticket_num |
| self.tag = tag |
| self.wake_event = threading.Event() |
| |
| def __lt__(self, other): |
| return (self.priority, self.ticket_num) < (other.priority, other.ticket_num) |
| |
| |
| class ModelManager: |
| """Manages model lifecycles, caching, and resource arbitration. |
| |
| This class acts as the central controller for acquiring model instances. |
| |
| 1. LRU Caching of idle models. |
| 2. Resource estimation and admission control (preventing OOM). |
| 3. Dynamic eviction of low-priority models, determined by count of |
| pending requests, when space is needed. |
| 4. 'Isolation Mode' for safely profiling unknown models. |
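
  Example (a minimal sketch; ``load_fn`` stands for any zero-argument
  callable that loads and returns a model):

    manager = ModelManager()
    instance = manager.acquire_model('my_model', load_fn)
    try:
      ...  # run inference with `instance`
    finally:
      manager.release_model('my_model', instance)
    manager.shutdown()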
| """ |
| def __init__( |
| self, |
| monitor: Optional['GPUMonitor'] = None, |
| slack_percentage: float = 0.10, |
| poll_interval: float = 0.5, |
| peak_window_seconds: float = 30.0, |
| min_data_points: int = 5, |
| smoothing_factor: float = 0.2, |
| eviction_cooldown_seconds: float = 10.0, |
| min_model_copies: int = 1, |
| wait_timeout_seconds: float = 300.0, |
| lock_timeout_seconds: float = 60.0, |
| verbose_logging: bool = False): |
| |
| self._estimator = ResourceEstimator( |
| min_data_points=min_data_points, |
| smoothing_factor=smoothing_factor, |
| verbose_logging=verbose_logging) |
| self._monitor = monitor if monitor else GPUMonitor( |
| poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) |
| self._slack_percentage = slack_percentage |
| |
| self._eviction_cooldown = eviction_cooldown_seconds |
| self._min_model_copies = min_model_copies |
| self._wait_timeout_seconds = wait_timeout_seconds |
| self._lock_timeout_seconds = lock_timeout_seconds |
| self._verbose_logging = verbose_logging |
| |
| # Resource State |
| self._models = defaultdict(list) |
| # Idle LRU used to track released models that |
| # can be freed or reused upon request. |
| self._idle_lru = OrderedDict() |
| self._active_counts = Counter() |
| self._total_active_jobs = 0 |
| self._pending_reservations = 0.0 |
| |
| # Isolation state used to profile unknown models, |
| # ensuring they run alone to get accurate readings. |
| # isolation_baseline represents the GPU usage before |
| # loading the unknown model. |
| self._isolation_mode = False |
| self._isolation_baseline = 0.0 |
| |
| # Waiting Queue and Ticketing to make sure we have fair ordering |
| # and also priority for unknown models. |
| self._wait_queue = [] |
| self._ticket_counter = itertools.count() |
| self._cancelled_tickets = set() |
    # TODO: Consider making the waiting smarter, e.g.
    # splitting read/write paths to avoid potential contention.
| self._cv = threading.Condition() |
| |
| self._monitor.start() |
| |
| def logging_info(self, message: str, *args): |
| if self._verbose_logging: |
| logger.info(message, *args) |
| |
  def all_models(self, tag: str) -> list[Any]:
| return self._models[tag] |
| |
  # Caller must hold the _cv lock
| def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: |
| if self._total_active_jobs > 0: |
| self.logging_info( |
| "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num) |
| self._cv.wait(timeout=self._lock_timeout_seconds) |
| # return False since we have waited and need to re-evaluate |
| # in caller to make sure our priority is still valid. |
| return False |
| |
| self.logging_info("Unknown model %s detected. Flushing GPU.", tag) |
| self._delete_all_models() |
| |
| self._isolation_mode = True |
| self._total_active_jobs += 1 |
| self._isolation_baseline, _, _ = self._monitor.get_stats() |
| self._monitor.reset_peak() |
| return True |
| |
  # Caller must hold the _cv lock
| def should_spawn_model(self, tag: str, ticket_num: int) -> bool: |
| curr, _, total = self._monitor.get_stats() |
| est_cost = self._estimator.get_estimate(tag) |
| limit = total * (1 - self._slack_percentage) |
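    # e.g. with 16000 MB total and the default 10% slack, new loads are
    # admitted only while projected usage stays below 14400 MB.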
| |
| # Use current usage for capacity check (ignore old spikes) |
| if (curr + self._pending_reservations + est_cost) <= limit: |
| self._pending_reservations += est_cost |
| self._total_active_jobs += 1 |
| self._active_counts[tag] += 1 |
| return True |
| |
| # Evict to make space (passing tag to check demand/existence) |
| if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): |
| return True |
| |
    # Log status for debugging before we start waiting
    idle_count = sum(
        1 for t, _, _ in self._idle_lru.values() if t == tag)
    other_idle_count = len(self._idle_lru) - idle_count
    total_model_count = sum(
        len(instances) for instances in self._models.values())
    curr, _, _ = self._monitor.get_stats()
    self.logging_info(
        "Waiting for resources to free up: "
        "tag=%s ticket num=%s model count=%s "
        "idle count=%s resource usage=%.1f MB "
        "total models count=%s other idle=%s",
        tag,
        ticket_num,
        len(self._models[tag]),
        idle_count,
        curr,
        total_model_count,
        other_idle_count)
    # Wait since we couldn't make space; the timeout guards
    # against a missed notify call.
    self._cv.wait(timeout=self._lock_timeout_seconds)
| return False |
| |
  def _wake_next_in_queue(self):
    # Clean up cancelled tickets at the head of the queue
    while self._wait_queue and self._wait_queue[
        0].ticket_num in self._cancelled_tickets:
      self._cancelled_tickets.remove(self._wait_queue[0].ticket_num)
      heapq.heappop(self._wait_queue)
    # Guard against the queue draining entirely during cleanup
    if self._wait_queue:
      self._wait_queue[0].wake_event.set()
| |
| def _wait_in_queue(self, ticket: QueueTicket): |
| self._cv.release() |
| try: |
| ticket.wake_event.wait(timeout=self._lock_timeout_seconds) |
| ticket.wake_event.clear() |
| finally: |
| self._cv.acquire() |
| |
| def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: |
| current_priority = 0 if self._estimator.is_unknown(tag) else 1 |
| ticket_num = next(self._ticket_counter) |
| my_ticket = QueueTicket(current_priority, ticket_num, tag) |
| |
| with self._cv: |
| # FAST PATH: Grab from idle LRU if available |
| if not self._isolation_mode: |
| cached_instance = self._try_grab_from_lru(tag) |
| if cached_instance: |
| return cached_instance |
| |
| # SLOW PATH: Enqueue and wait for turn to acquire model, |
| # with unknown models having priority and order enforced |
| # by ticket number as FIFO. |
| self.logging_info( |
| "Acquire Queued: tag=%s, priority=%d " |
| "total models count=%s ticket num=%s", |
| tag, |
| current_priority, |
| len(self._models[tag]), |
| ticket_num) |
| heapq.heappush(self._wait_queue, my_ticket) |
| |
| est_cost = 0.0 |
| is_unknown = False |
| wait_time_start = time.time() |
| |
| try: |
| while True: |
| wait_time_elapsed = time.time() - wait_time_start |
| if wait_time_elapsed > self._wait_timeout_seconds: |
| raise RuntimeError( |
| f"Timeout waiting to acquire model: {tag} " |
| f"after {wait_time_elapsed:.1f} seconds.") |
        if not self._wait_queue or (
            self._wait_queue[0].ticket_num != ticket_num):
          self.logging_info(
              "Waiting for our turn: tag=%s ticket num=%s", tag, ticket_num)
| self._wait_in_queue(my_ticket) |
| continue |
| |
| # Re-evaluate priority in case model became known during wait |
| is_unknown = self._estimator.is_unknown(tag) |
| real_priority = 0 if is_unknown else 1 |
| |
| # If priority changed, reinsert into queue and wait |
| if current_priority != real_priority: |
| heapq.heappop(self._wait_queue) |
| current_priority = real_priority |
| my_ticket = QueueTicket(current_priority, ticket_num, tag) |
| heapq.heappush(self._wait_queue, my_ticket) |
| self._wake_next_in_queue() |
| continue |
| |
| # Try grab from LRU again in case model was released during wait |
| cached_instance = self._try_grab_from_lru(tag) |
| if cached_instance: |
| return cached_instance |
| |
| # Path A: Isolation |
| if is_unknown: |
| if self.try_enter_isolation_mode(tag, ticket_num): |
| # We got isolation, can proceed to spawn |
| break |
| else: |
| # We waited, need to re-evaluate our turn |
| # because priority may have changed during the wait |
| continue |
| |
| # Path B: Concurrent |
| else: |
| if self._isolation_mode: |
            self.logging_info(
                "Waiting due to isolation in progress: tag=%s ticket num=%s",
                tag,
                ticket_num)
| self._wait_in_queue(my_ticket) |
| continue |
| |
| if self.should_spawn_model(tag, ticket_num): |
| est_cost = self._estimator.get_estimate(tag) |
| # We can proceed to spawn since we have resources |
| break |
| else: |
| # We waited, need to re-evaluate our turn |
| # because priority may have changed during the wait |
| continue |
| |
| finally: |
| # Remove self from wait queue once done |
| if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num: |
| heapq.heappop(self._wait_queue) |
| else: |
          # Mark as cancelled so this ticket is skipped once it reaches
          # the head of the queue
| self._cancelled_tickets.add(ticket_num) |
| self._wake_next_in_queue() |
| |
| return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) |
| |
| def release_model(self, tag: str, instance: Any): |
| with self._cv: |
| try: |
| self._total_active_jobs -= 1 |
| if self._active_counts[tag] > 0: |
| self._active_counts[tag] -= 1 |
| |
| self._idle_lru[id(instance)] = (tag, instance, time.time()) |
| |
| # Update estimator with latest stats |
| _, peak_during_job, _ = self._monitor.get_stats() |
| |
| if self._isolation_mode and self._active_counts[tag] == 0: |
| # For isolation mode, we directly set the initial estimate |
| # so that we can quickly learn the model cost. |
| cost = max(0, peak_during_job - self._isolation_baseline) |
| self._estimator.set_initial_estimate(tag, cost) |
| self._isolation_mode = False |
| self._isolation_baseline = 0.0 |
| else: |
| # Regular update for known models |
| snapshot = { |
| t: len(instances) |
| for t, instances in self._models.items() if len(instances) > 0 |
| } |
| if snapshot: |
| self._estimator.add_observation(snapshot, peak_during_job) |
| |
| finally: |
| self._wake_next_in_queue() |
| self._cv.notify_all() |
| |
| def _try_grab_from_lru(self, tag: str) -> Any: |
| target_key = None |
| target_instance = None |
| |
| for key, (t, instance, _) in reversed(self._idle_lru.items()): |
| if t == tag: |
| target_key = key |
| target_instance = instance |
| break |
| |
| if target_instance: |
| # Found an idle model, remove from LRU and return |
| del self._idle_lru[target_key] |
| self._active_counts[tag] += 1 |
| self._total_active_jobs += 1 |
| return target_instance |
| |
| self.logging_info("No idle model found for tag: %s", tag) |
| return None |
| |
  def _evict_to_make_space(
      self, limit: float, est_cost: float, requesting_tag: str) -> bool:
    """Evicts models based on demand magnitude plus recency/surplus tiers.

    Crucially, if zero copies of 'requesting_tag' exist, we FORCE eviction
    of the lowest-demand candidate to avoid starvation.
    Returns True if space was made, False otherwise.
    """
| curr, _, _ = self._monitor.get_stats() |
| projected_usage = curr + self._pending_reservations + est_cost |
| |
| if projected_usage <= limit: |
| # Memory usage changed and we are already under limit |
| return True |
| |
| now = time.time() |
| |
| # Calculate the demand from the wait queue |
| # TODO: Also factor in the active counts to avoid thrashing |
| demand_map = Counter() |
| for item in self._wait_queue: |
| demand_map[item.tag] += 1 |
| |
| my_demand = demand_map[requesting_tag] |
| am_i_starving = len(self._models[requesting_tag]) == 0 |
| |
| candidates = [] |
| for key, (tag, instance, release_time) in self._idle_lru.items(): |
| candidate_demand = demand_map[tag] |
| |
| # TODO: Try to avoid churn if demand is similar |
| if not am_i_starving and candidate_demand >= my_demand: |
| continue |
| |
| # Attempts to score candidates based on hotness and manually |
| # specified minimum copies. Demand is weighted heavily to |
| # ensure we evict low-demand models first. |
| age = now - release_time |
| is_cold = age >= self._eviction_cooldown |
| |
| total_copies = len(self._models[tag]) |
| is_surplus = total_copies > self._min_model_copies |
| |
      if is_cold and is_surplus:
        tier = 0
      elif not is_cold and is_surplus:
        tier = 1
      elif is_cold and not is_surplus:
        tier = 2
      else:
        tier = 3
| |
| score = (candidate_demand * 10) + tier |
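      # Lower scores evict first: e.g. an idle copy nobody is waiting for
      # (demand 0, tier 3, score 3) goes before a cold surplus copy that
      # one waiter wants (demand 1, tier 0, score 10).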
| |
| candidates.append((score, release_time, key, tag, instance)) |
| |
| candidates.sort(key=lambda x: (x[0], x[1])) |
| |
| # Evict candidates until we are under limit |
| for score, _, key, tag, instance in candidates: |
| if projected_usage <= limit: |
| break |
| |
      if key not in self._idle_lru:
        continue
| |
| self._perform_eviction(key, tag, instance, score) |
| |
| curr, _, _ = self._monitor.get_stats() |
| projected_usage = curr + self._pending_reservations + est_cost |
| |
| return projected_usage <= limit |
| |
| def _delete_instance(self, instance: Any): |
| if isinstance(instance, str): |
| # If the instance is a string, it's a uuid used |
| # to retrieve the model from MultiProcessShared |
| try: |
| multi_process_shared.MultiProcessShared( |
| lambda: "N/A", tag=instance).unsafe_hard_delete() |
| except (EOFError, OSError, BrokenPipeError): |
| # This can happen even in normal operation. |
| pass |
| if hasattr(instance, 'mock_model_unsafe_hard_delete'): |
| # Call the mock unsafe hard delete method for testing |
| instance.mock_model_unsafe_hard_delete() |
    del instance  # Drop this frame's reference; gc reclaims once all are gone
| |
| def _perform_eviction(self, key: str, tag: str, instance: Any, score: int): |
| self.logging_info("Evicting Model: %s (Score %d)", tag, score) |
| curr, _, _ = self._monitor.get_stats() |
| self.logging_info("Resource Usage Before Eviction: %.1f MB", curr) |
| |
| if key in self._idle_lru: |
| del self._idle_lru[key] |
| |
| for i, inst in enumerate(self._models[tag]): |
| if instance == inst: |
| del self._models[tag][i] |
| break |
| |
| self._delete_instance(instance) |
| gc.collect() |
| torch.cuda.empty_cache() |
| self._monitor.refresh() |
| self._monitor.reset_peak() |
| curr, _, _ = self._monitor.get_stats() |
| self.logging_info("Resource Usage After Eviction: %.1f MB", curr) |
| |
| def _spawn_new_model( |
| self, |
| tag: str, |
| loader_func: Callable[[], Any], |
| is_unknown: bool, |
| est_cost: float) -> Any: |
| try: |
| with self._cv: |
| self.logging_info("Loading Model: %s (Unknown: %s)", tag, is_unknown) |
| baseline_snap, _, _ = self._monitor.get_stats() |
| instance = loader_func() |
| _, peak_during_load, _ = self._monitor.get_stats() |
| |
| snapshot = {tag: 1} |
| self._estimator.add_observation( |
| snapshot, peak_during_load - baseline_snap) |
| |
| if not is_unknown: |
| self._pending_reservations = max( |
| 0.0, self._pending_reservations - est_cost) |
| self._models[tag].append(instance) |
| return instance |
| |
| except Exception as e: |
| logger.error("Load Failed: %s. Error: %s", tag, e) |
| with self._cv: |
| self._total_active_jobs -= 1 |
| if is_unknown: |
| self._isolation_mode = False |
| self._isolation_baseline = 0.0 |
| else: |
| self._pending_reservations = max( |
| 0.0, self._pending_reservations - est_cost) |
| self._active_counts[tag] -= 1 |
| self._cv.notify_all() |
      raise  # Re-raise the original exception with its traceback
| |
| def _delete_all_models(self): |
| self._idle_lru.clear() |
| for _, instances in self._models.items(): |
| for instance in instances: |
| self._delete_instance(instance) |
| self._models.clear() |
| self._active_counts.clear() |
| gc.collect() |
| torch.cuda.empty_cache() |
| self._monitor.refresh() |
| self._monitor.reset_peak() |
| |
| def _force_reset(self): |
| logger.warning("Force Reset Triggered") |
| self._delete_all_models() |
| self._models = defaultdict(list) |
| self._idle_lru = OrderedDict() |
| self._active_counts = Counter() |
| self._wait_queue = [] |
| self._total_active_jobs = 0 |
| self._pending_reservations = 0.0 |
| self._isolation_mode = False |
| self._isolation_baseline = 0.0 |
| |
| def shutdown(self): |
| self._delete_all_models() |
| self._monitor.stop() |
| |
| def __del__(self): |
| self.shutdown() |
| |
  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.shutdown()