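"""Singleton manager for the TTS model lifecycle.

Lazily constructs a TTSService, loads its model on first use, tracks
in-flight operations, and unloads the model when idle, logging process
RSS before and after each transition.
"""
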
import asyncio
import gc
import logging
import os
import time
from typing import Optional

# psutil is optional; it is only used to report process RSS around model
# load/unload, so the module degrades gracefully without it.
_proc = None
try:
    import psutil  # type: ignore

    _proc = psutil.Process(os.getpid())
except Exception:
    psutil = None  # type: ignore


def _rss_mb() -> float:
    """Return current process RSS in MB, or -1.0 if unavailable."""
    global _proc
    try:
        if _proc is None and psutil is not None:
            _proc = psutil.Process(os.getpid())
        if _proc is not None:
            return _proc.memory_info().rss / (1024 * 1024)
    except Exception:
        return -1.0
    return -1.0


try:
    import torch  # Optional; used for device detection and cache cleanup
except Exception:  # pragma: no cover - torch may not be present in some envs
    torch = None  # type: ignore

from app import config
from app.services.tts_service import TTSService

logger = logging.getLogger(__name__)


class ModelManager:
    """Process-wide singleton owning the TTSService instance and its model."""

    _instance: Optional["ModelManager"] = None

    def __init__(self):
        self._service: Optional[TTSService] = None
        self._last_used: float = time.time()
        self._active: int = 0
        self._lock = asyncio.Lock()
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls) -> "ModelManager":
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def _ensure_service(self) -> None:
        if self._service is None:
            # TTSService expects an explicit device like "mps"/"cpu"/"cuda",
            # so resolve the configured "auto" here: prefer MPS on Apple
            # silicon, then CUDA, falling back to CPU.
            device = getattr(config, "DEVICE", "auto")
            if device == "auto":
                device = "cpu"
                if torch is not None:
                    try:
                        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                            device = "mps"
                        elif torch.cuda.is_available():
                            device = "cuda"
                    except Exception:
                        device = "cpu"
            self._service = TTSService(device=device)

    async def load(self) -> None:
        async with self._lock:
            await self._ensure_service()
            if self._service and self._service.model is None:
                before_mb = _rss_mb()
                logger.info(
                    "Loading TTS model (device=%s)... (rss_before=%.1f MB)",
                    self._service.device,
                    before_mb,
                )
                self._service.load_model()
                after_mb = _rss_mb()
                if after_mb >= 0 and before_mb >= 0:
                    logger.info(
                        "TTS model loaded (rss_after=%.1f MB, delta=%.1f MB)",
                        after_mb,
                        after_mb - before_mb,
                    )
            self._last_used = time.time()

    async def unload(self) -> None:
        async with self._lock:
            if not self._service:
                return
            if self._active > 0:
                logger.debug("Skip unload: %d active operations", self._active)
                return
            if self._service.model is not None:
                before_mb = _rss_mb()
                logger.info(
                    "Unloading idle TTS model... (rss_before=%.1f MB, active=%d)",
                    before_mb,
                    self._active,
                )
                self._service.unload_model()
                # Drop the service instance as well to release any lingering refs
                self._service = None
                # Force GC and attempt allocator cache cleanup
                try:
                    gc.collect()
                finally:
                    if torch is not None:
                        try:
                            if hasattr(torch, "cuda") and torch.cuda.is_available():
                                torch.cuda.empty_cache()
                        except Exception:
                            logger.debug("cuda.empty_cache() failed", exc_info=True)
                        try:
                            # MPS empty_cache may exist depending on torch version
                            mps = getattr(torch, "mps", None)
                            if mps is not None and hasattr(mps, "empty_cache"):
                                mps.empty_cache()
                        except Exception:
                            logger.debug("mps.empty_cache() failed", exc_info=True)
                after_mb = _rss_mb()
                if after_mb >= 0 and before_mb >= 0:
                    logger.info(
                        "Idle unload complete (rss_after=%.1f MB, delta=%.1f MB)",
                        after_mb,
                        after_mb - before_mb,
                    )
            self._last_used = time.time()

    async def get_service(self) -> TTSService:
        if not self._service or self._service.model is None:
            await self.load()
        self._last_used = time.time()
        return self._service  # type: ignore[return-value]

    async def _inc(self) -> None:
        async with self._counter_lock:
            self._active += 1

    async def _dec(self) -> None:
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self) -> float:
        return self._last_used

    def is_loaded(self) -> bool:
        return bool(self._service and self._service.model is not None)

    def active(self) -> int:
        return self._active

    def using(self):
        """Return an async context manager that marks an operation as active."""
        manager = self

        class _Ctx:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()

        return _Ctx()
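
# Usage sketch (hypothetical callers, not part of the original module): a
# request handler wraps its work in `using()` so the idle watchdog below
# will not unload the model mid-operation. The 300 s idle timeout and 30 s
# poll interval are assumed example values, not something this module sets.
async def _example_idle_watchdog(idle_timeout_s: float = 300.0) -> None:
    manager = ModelManager.instance()
    while True:
        await asyncio.sleep(30)  # hypothetical poll interval
        idle_for = time.time() - manager.last_used()
        if manager.is_loaded() and manager.active() == 0 and idle_for > idle_timeout_s:
            # unload() is a no-op if an operation became active meanwhile
            await manager.unload()


async def _example_request() -> None:
    manager = ModelManager.instance()
    async with manager.using():  # keeps the watchdog from unloading mid-request
        service = await manager.get_service()
        # ... call into `service` here; the synthesis API is defined by TTSService ...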