"""Lazy loading, unloading and usage tracking for the shared TTS model."""

import asyncio
import logging
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

from app import config
from app.services.tts_service import TTSService

logger = logging.getLogger(__name__)


class ModelManager:
    """Own a single TTSService and manage the lifecycle of its model.

    The service is created lazily by ``load()`` / ``get_service()``. The model
    can be released with ``unload()``, which is skipped while operations
    registered through ``using()`` are in flight. The time of last use is
    recorded so callers can decide when the model counts as idle.
    """

    # Shared instance stored on the class; created by the first instance() call.
    _instance: Optional["ModelManager"] = None

    def __init__(self) -> None:
        self._service: Optional[TTSService] = None
        # time.time() of the most recent load / unload / use.
        self._last_used: float = time.time()
        # Number of operations currently inside a ``using()`` context.
        self._active: int = 0
        # Serializes service creation, model loading and model unloading.
        self._lock = asyncio.Lock()
        # Guards updates to ``_active``.
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls) -> "ModelManager":
        """Return the shared manager, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Map the ``"auto"`` device setting to a concrete device name.

        Any value other than ``"auto"`` is returned unchanged. For ``"auto"``
        the preference order is ``"mps"``, then ``"cuda"``, then ``"cpu"``.
        If torch cannot be imported or probing fails, ``"cpu"`` is used.
        """
        if device != "auto":
            return device
        try:
            # Deferred import: torch is only needed here for auto-detection.
            import torch

            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
            if torch.cuda.is_available():
                return "cuda"
            return "cpu"
        except Exception:
            # Best-effort probe: keep the CPU fallback, but record why instead
            # of swallowing the failure silently.
            logger.warning(
                "Device auto-detection failed; falling back to CPU", exc_info=True
            )
            return "cpu"

    async def _ensure_service(self) -> None:
        """Create the TTSService on the first call; later calls do nothing."""
        if self._service is not None:
            return
        # config.DEVICE is optional and defaults to "auto". Per the original
        # author, TTSService expects an explicit device ("mps"/"cpu"/"cuda"),
        # so "auto" is resolved here before construction.
        device = self._resolve_device(getattr(config, "DEVICE", "auto"))
        self._service = TTSService(device=device)

    async def load(self) -> None:
        """Ensure the service exists and its model is loaded.

        Always refreshes the last-used timestamp.
        """
        async with self._lock:
            await self._ensure_service()
            if self._service and self._service.model is None:
                logger.info("Loading TTS model (device=%s)...", self._service.device)
                # NOTE(review): load_model() is a synchronous call inside a
                # coroutine, so it blocks the event loop while it runs. Consider
                # asyncio.to_thread if loading is slow and TTSService tolerates
                # being loaded off the loop thread; confirm before changing.
                self._service.load_model()
            self._last_used = time.time()

    async def unload(self) -> None:
        """Unload the model unless operations are still active.

        Does nothing if the service was never created, if any ``using()``
        context is open, or if the model is not loaded. The last-used
        timestamp is refreshed only when the model is actually unloaded.
        """
        async with self._lock:
            if not self._service:
                return
            if self._active > 0:
                logger.debug("Skip unload: %d active operations", self._active)
                return
            if self._service.model is not None:
                logger.info("Unloading idle TTS model...")
                self._service.unload_model()
                self._last_used = time.time()

    async def get_service(self) -> TTSService:
        """Return the service, loading the model first if needed.

        The returned service is protected from ``unload()`` only while the
        caller is inside a ``using()`` context; ``unload()`` checks nothing
        else.
        """
        if not self._service or self._service.model is None:
            await self.load()
        self._last_used = time.time()
        return self._service  # type: ignore[return-value]

    async def _inc(self) -> None:
        """Register the start of an operation."""
        async with self._counter_lock:
            self._active += 1

    async def _dec(self) -> None:
        """Register the end of an operation and refresh the last-used time."""
        async with self._counter_lock:
            # Clamp at zero so an unmatched decrement cannot go negative.
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self) -> float:
        """Return the ``time.time()`` timestamp of the most recent use."""
        return self._last_used

    def is_loaded(self) -> bool:
        """Return True if the service exists and its model is loaded."""
        return bool(self._service and self._service.model is not None)

    def active(self) -> int:
        """Return the number of operations currently inside ``using()``."""
        return self._active

    @asynccontextmanager
    async def using(self) -> AsyncIterator["ModelManager"]:
        """Async context that marks an operation as active.

        While the context is open, ``unload()`` is skipped. The counter is
        decremented on exit even when the body raises; exceptions propagate.
        Yields this manager.
        """
        await self._inc()
        try:
            yield self
        finally:
            await self._dec()