"""Lazy lifecycle manager for the TTS model.

Provides a process-wide singleton (`ModelManager.instance()`) that loads the
TTS model on first use, tracks activity so an idle-unload sweep can safely
release it, and logs process RSS deltas around load/unload for observability.
"""

import asyncio
import gc
import logging
import os
import time
from typing import Optional

# Optional psutil: used only for RSS reporting; everything degrades gracefully
# to "-1.0 MB" sentinels when it is unavailable.
_proc = None
try:
    import psutil  # type: ignore

    _proc = psutil.Process(os.getpid())
except Exception:
    psutil = None  # type: ignore


def _rss_mb() -> float:
    """Return current process RSS in MB, or -1.0 if unavailable."""
    global _proc
    try:
        # Lazily (re)create the Process handle if the import-time attempt failed.
        if _proc is None and psutil is not None:
            _proc = psutil.Process(os.getpid())
        if _proc is not None:
            return _proc.memory_info().rss / (1024 * 1024)
    except Exception:
        return -1.0
    return -1.0


try:
    import torch  # Optional; used for device auto-detection and cache cleanup
except Exception:  # pragma: no cover - torch may not be present in some envs
    torch = None  # type: ignore

from app import config
from app.services.tts_service import TTSService

logger = logging.getLogger(__name__)


class ModelManager:
    """Singleton wrapper around :class:`TTSService` with idle-unload support.

    Concurrency notes (asyncio, single event loop assumed):
      * ``_lock`` serializes load/unload so they never interleave.
      * ``_counter_lock`` guards the active-operation counter used by
        :meth:`using` so ``unload`` can skip while work is in flight.
    """

    _instance: Optional["ModelManager"] = None

    def __init__(self):
        self._service: Optional[TTSService] = None
        self._last_used: float = time.time()  # updated on every touch point
        self._active: int = 0  # number of in-flight operations (see using())
        self._lock = asyncio.Lock()
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls) -> "ModelManager":
        """Return the process-wide singleton, creating it on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @staticmethod
    def _resolve_device() -> str:
        """Map the configured device to a concrete one.

        ``"auto"`` resolves to "mps" on Apple Silicon, "cuda" when available,
        otherwise "cpu". Uses the module-level optional ``torch`` (no local
        re-import), so a missing torch simply yields "cpu".
        """
        device = getattr(config, "DEVICE", "auto")
        if device != "auto":
            return device
        if torch is None:
            return "cpu"
        try:
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
            if torch.cuda.is_available():
                return "cuda"
        except Exception:
            pass  # any probing failure falls through to CPU
        return "cpu"

    async def _ensure_service(self) -> None:
        """Create the TTSService (without loading the model) if needed."""
        if self._service is None:
            self._service = TTSService(device=self._resolve_device())

    async def load(self) -> None:
        """Load the TTS model if not already loaded, logging RSS deltas."""
        async with self._lock:
            await self._ensure_service()
            if self._service and self._service.model is None:
                before_mb = _rss_mb()
                logger.info(
                    "Loading TTS model (device=%s)... (rss_before=%.1f MB)",
                    self._service.device,
                    before_mb,
                )
                self._service.load_model()
                after_mb = _rss_mb()
                if after_mb >= 0 and before_mb >= 0:
                    logger.info(
                        "TTS model loaded (rss_after=%.1f MB, delta=%.1f MB)",
                        after_mb,
                        after_mb - before_mb,
                    )
            self._last_used = time.time()

    async def unload(self) -> None:
        """Unload the model and drop the service, unless work is in flight.

        Best-effort: after releasing Python references, forces a GC pass and
        asks torch's CUDA/MPS allocators to release cached blocks so RSS
        actually shrinks.
        """
        async with self._lock:
            if not self._service:
                return
            if self._active > 0:
                # An operation entered via using(); try again on the next sweep.
                logger.debug("Skip unload: %d active operations", self._active)
                return
            if self._service.model is not None:
                before_mb = _rss_mb()
                logger.info(
                    "Unloading idle TTS model... (rss_before=%.1f MB, active=%d)",
                    before_mb,
                    self._active,
                )
                self._service.unload_model()
                # Drop the service instance as well to release any lingering refs
                self._service = None
                # Force GC and attempt allocator cache cleanup
                try:
                    gc.collect()
                finally:
                    if torch is not None:
                        try:
                            if hasattr(torch, "cuda") and torch.cuda.is_available():
                                torch.cuda.empty_cache()
                        except Exception:
                            logger.debug("cuda.empty_cache() failed", exc_info=True)
                        try:
                            # MPS empty_cache may exist depending on torch version
                            mps = getattr(torch, "mps", None)
                            if mps is not None and hasattr(mps, "empty_cache"):
                                mps.empty_cache()
                        except Exception:
                            logger.debug("mps.empty_cache() failed", exc_info=True)
                after_mb = _rss_mb()
                if after_mb >= 0 and before_mb >= 0:
                    logger.info(
                        "Idle unload complete (rss_after=%.1f MB, delta=%.1f MB)",
                        after_mb,
                        after_mb - before_mb,
                    )
            self._last_used = time.time()

    async def get_service(self) -> TTSService:
        """Return the service, loading the model first when necessary."""
        if not self._service or self._service.model is None:
            await self.load()
        self._last_used = time.time()
        return self._service  # type: ignore[return-value]

    async def _inc(self) -> None:
        """Register one in-flight operation (called by using())."""
        async with self._counter_lock:
            self._active += 1

    async def _dec(self) -> None:
        """Unregister one in-flight operation; clamps at zero."""
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self) -> float:
        """Timestamp (time.time()) of the most recent activity."""
        return self._last_used

    def is_loaded(self) -> bool:
        """True when a service exists and its model is loaded."""
        return bool(self._service and self._service.model is not None)

    def active(self) -> int:
        """Number of operations currently inside a using() context."""
        return self._active

    def using(self):
        """Async context manager marking an operation as active.

        While entered, :meth:`unload` refuses to drop the model. Yields the
        manager itself.
        """
        manager = self

        class _Ctx:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()

        return _Ctx()