import asyncio
import logging
import time
from typing import Optional

from app import config
from app.services.tts_service import TTSService

logger = logging.getLogger(__name__)


class ModelManager:
    """Process-wide singleton that owns the shared TTSService and its load/unload lifecycle."""

    _instance: Optional["ModelManager"] = None

    def __init__(self):
        self._service: Optional[TTSService] = None
        self._last_used: float = time.time()
        self._active: int = 0
        self._lock = asyncio.Lock()
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls) -> "ModelManager":
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    async def _ensure_service(self) -> None:
        if self._service is None:
            # Use the configured device if set; otherwise default to "auto" and resolve it below.
            device = getattr(config, "DEVICE", "auto")
            # TTSService expects an explicit device string ("mps"/"cuda"/"cpu"),
            # so map "auto" to the best available backend.
            if device == "auto":
                try:
                    import torch

                    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                        device = "mps"
                    elif torch.cuda.is_available():
                        device = "cuda"
                    else:
                        device = "cpu"
                except Exception:
                    device = "cpu"
            self._service = TTSService(device=device)

    async def load(self) -> None:
        async with self._lock:
            await self._ensure_service()
            if self._service and self._service.model is None:
                logger.info("Loading TTS model (device=%s)...", self._service.device)
                self._service.load_model()
            self._last_used = time.time()

    async def unload(self) -> None:
        async with self._lock:
            if not self._service:
                return
            if self._active > 0:
                logger.debug("Skip unload: %d active operations", self._active)
                return
            if self._service.model is not None:
                logger.info("Unloading idle TTS model...")
                self._service.unload_model()
            self._last_used = time.time()

    async def get_service(self) -> TTSService:
        if not self._service or self._service.model is None:
            await self.load()
        self._last_used = time.time()
        return self._service  # type: ignore[return-value]

    async def _inc(self) -> None:
        async with self._counter_lock:
            self._active += 1

    async def _dec(self) -> None:
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self) -> float:
        return self._last_used

    def is_loaded(self) -> bool:
        return bool(self._service and self._service.model is not None)

    def active(self) -> int:
        return self._active

    def using(self):
        """Return an async context manager that counts an in-flight operation for its duration."""
        manager = self

        class _Ctx:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()

        return _Ctx()
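

# --- Usage sketch (illustrative only; not part of the module's public surface) ---
# The two coroutines below, and the 300-second idle timeout, are assumptions for
# demonstration. Only the ModelManager API defined above is taken from this file;
# TTSService's own synthesis API lives elsewhere and is not shown here.

async def _example_request_handler() -> None:
    # Hold the usage guard for the whole operation so the idle watchdog
    # cannot unload the model while a request is in flight.
    manager = ModelManager.instance()
    async with manager.using():
        service = await manager.get_service()
        # ... call into `service` here (see app.services.tts_service) ...
        _ = service


async def _example_idle_unload_watchdog(idle_seconds: float = 300.0) -> None:
    # Background task: free the model after a period with no activity.
    manager = ModelManager.instance()
    while True:
        await asyncio.sleep(30)
        idle_for = time.time() - manager.last_used()
        if manager.is_loaded() and manager.active() == 0 and idle_for > idle_seconds:
            await manager.unload()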