backend: implement idle TTS model eviction

- Add MODEL_EVICTION_ENABLED, MODEL_IDLE_TIMEOUT_SECONDS, MODEL_IDLE_CHECK_INTERVAL_SECONDS in app/config.py
- Add ModelManager service to manage TTSService load/unload with usage tracking
- Add background idle reaper in app/main.py (startup/shutdown hooks)
- Refactor dialog router to use ModelManager dependency instead of per-request load/unload
This commit is contained in:
parent
41f95cdee3
commit
cbc164c7a3
|
@ -67,6 +67,14 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
|
|||
# Device configuration
DEVICE = os.environ.get("DEVICE", "auto")

# --- Model idle eviction configuration ---
# Master switch for idle-based eviction of the loaded TTS model.
MODEL_EVICTION_ENABLED = os.environ.get("MODEL_EVICTION_ENABLED", "true").lower() == "true"
# Seconds of inactivity after which the model is unloaded (0 disables eviction).
MODEL_IDLE_TIMEOUT_SECONDS = int(os.environ.get("MODEL_IDLE_TIMEOUT_SECONDS", "900"))
# Polling interval, in seconds, for the background idle reaper.
MODEL_IDLE_CHECK_INTERVAL_SECONDS = int(os.environ.get("MODEL_IDLE_CHECK_INTERVAL_SECONDS", "60"))
|
||||
|
||||
# Make sure the on-disk folders for speaker samples and temporary TTS output
# exist before any request handler tries to write into them.
TTS_TEMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
@ -2,6 +2,10 @@ from fastapi import FastAPI
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from app.routers import speakers, dialog # Import the routers
|
||||
from app import config
|
||||
|
||||
|
@ -38,3 +42,39 @@ config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
|
|||
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
|
||||
|
||||
# Further endpoints for speakers, dialog generation, etc., will be added here.
|
||||
|
||||
# --- Background task: idle model reaper ---
|
||||
logger = logging.getLogger("app.model_reaper")
|
||||
|
||||
@app.on_event("startup")
async def _start_model_reaper():
    """Spawn the background task that evicts the TTS model after idle periods.

    The task is stored on ``app.state`` so the shutdown hook can cancel it.
    """
    from app.services.model_manager import ModelManager

    async def idle_reaper():
        # Sleep-first polling loop: startup is never blocked, and each cycle
        # re-reads the config so toggles take effect without a restart.
        while True:
            try:
                await asyncio.sleep(config.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
                if not getattr(config, "MODEL_EVICTION_ENABLED", True):
                    continue
                timeout = getattr(config, "MODEL_IDLE_TIMEOUT_SECONDS", 0)
                if timeout <= 0:
                    # A non-positive timeout disables eviction entirely.
                    continue
                manager = ModelManager.instance()
                idle_for = time.time() - manager.last_used()
                # Only evict when loaded, unused by any request, and idle long enough.
                if manager.is_loaded() and manager.active() == 0 and idle_for >= timeout:
                    logger.info("Idle timeout reached (%.0fs). Unloading model...", timeout)
                    await manager.unload()
            except asyncio.CancelledError:
                # Shutdown requested: exit the loop cleanly.
                break
            except Exception:
                # Keep the reaper alive on unexpected errors; just log them.
                logger.exception("Model reaper encountered an error")

    app.state._model_reaper_task = asyncio.create_task(idle_reaper())
|
||||
|
||||
|
||||
@app.on_event("shutdown")
async def _stop_model_reaper():
    """Cancel the idle-reaper task and wait for it to finish.

    Bug fix: ``asyncio.CancelledError`` derives from ``BaseException`` (not
    ``Exception``) on Python 3.8+, so ``contextlib.suppress(Exception)`` would
    let the cancellation propagate out of the shutdown hook if the task had
    not yet swallowed it. Suppress it explicitly alongside ordinary errors.
    """
    task = getattr(app.state, "_model_reaper_task", None)
    if task:
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError, Exception):
            await task
|
||||
|
|
|
@ -9,6 +9,8 @@ from app.services.speaker_service import SpeakerManagementService
|
|||
from app.services.dialog_processor_service import DialogProcessorService
|
||||
from app.services.audio_manipulation_service import AudioManipulationService
|
||||
from app import config
|
||||
from typing import AsyncIterator
|
||||
from app.services.model_manager import ModelManager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -16,9 +18,12 @@ router = APIRouter()
|
|||
# These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init.
|
||||
# For now, direct instantiation or simple Depends is fine.
|
||||
|
||||
def get_tts_service():
|
||||
# Consider making device configurable
|
||||
return TTSService(device="mps")
|
||||
async def get_tts_service() -> AsyncIterator[TTSService]:
    """FastAPI dependency yielding the shared TTSService.

    A usage token is held for the entire request so the idle reaper will not
    unload the model while a generation is in flight.
    """
    mgr = ModelManager.instance()
    async with mgr.using():
        yield await mgr.get_service()
|
||||
|
||||
def get_speaker_management_service():
    """Dependency: provide a fresh SpeakerManagementService per request."""
    return SpeakerManagementService()
|
||||
|
@ -32,7 +37,7 @@ def get_dialog_processor_service(
|
|||
def get_audio_manipulation_service():
    """Dependency: provide a fresh AudioManipulationService per request."""
    return AudioManipulationService()
|
||||
|
||||
# --- Helper function to manage TTS model loading/unloading ---
|
||||
# --- Helper imports ---
|
||||
|
||||
from app.models.dialog_models import SpeechItem, SilenceItem
|
||||
from app.services.tts_service import TTSService
|
||||
|
@ -128,19 +133,7 @@ async def generate_line(
|
|||
detail=error_detail
|
||||
)
|
||||
|
||||
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
||||
"""Loads TTS model, executes task, then unloads model."""
|
||||
try:
|
||||
print("API: Loading TTS model...")
|
||||
tts_service.load_model()
|
||||
return await task_function(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Log or handle specific exceptions if needed before re-raising
|
||||
print(f"API: Error during TTS model lifecycle or task execution: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("API: Unloading TTS model...")
|
||||
tts_service.unload_model()
|
||||
# Removed per-request load/unload in favor of ModelManager idle eviction.
|
||||
|
||||
async def process_dialog_flow(
|
||||
request: DialogRequest,
|
||||
|
@ -274,12 +267,10 @@ async def generate_dialog_endpoint(
|
|||
- Concatenates all audio segments into a single file.
|
||||
- Creates a ZIP archive of all individual segments and the concatenated file.
|
||||
"""
|
||||
# Wrap the core processing logic with model loading/unloading
|
||||
return await manage_tts_model_lifecycle(
|
||||
tts_service,
|
||||
process_dialog_flow,
|
||||
request=request,
|
||||
dialog_processor=dialog_processor,
|
||||
# Execute core processing; ModelManager dependency keeps the model marked "in use".
|
||||
return await process_dialog_flow(
|
||||
request=request,
|
||||
dialog_processor=dialog_processor,
|
||||
audio_manipulator=audio_manipulator,
|
||||
background_tasks=background_tasks
|
||||
background_tasks=background_tasks,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app import config
|
||||
from app.services.tts_service import TTSService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager:
|
||||
_instance: Optional["ModelManager"] = None
|
||||
|
||||
def __init__(self):
|
||||
self._service: Optional[TTSService] = None
|
||||
self._last_used: float = time.time()
|
||||
self._active: int = 0
|
||||
self._lock = asyncio.Lock()
|
||||
self._counter_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> "ModelManager":
|
||||
if not cls._instance:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
async def _ensure_service(self) -> None:
|
||||
if self._service is None:
|
||||
# Use configured device, default is handled by TTSService itself
|
||||
device = getattr(config, "DEVICE", "auto")
|
||||
# TTSService presently expects explicit device like "mps"/"cpu"/"cuda"; map "auto" to "mps" on Mac otherwise cpu
|
||||
if device == "auto":
|
||||
try:
|
||||
import torch
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
except Exception:
|
||||
device = "cpu"
|
||||
self._service = TTSService(device=device)
|
||||
|
||||
async def load(self) -> None:
|
||||
async with self._lock:
|
||||
await self._ensure_service()
|
||||
if self._service and self._service.model is None:
|
||||
logger.info("Loading TTS model (device=%s)...", self._service.device)
|
||||
self._service.load_model()
|
||||
self._last_used = time.time()
|
||||
|
||||
async def unload(self) -> None:
|
||||
async with self._lock:
|
||||
if not self._service:
|
||||
return
|
||||
if self._active > 0:
|
||||
logger.debug("Skip unload: %d active operations", self._active)
|
||||
return
|
||||
if self._service.model is not None:
|
||||
logger.info("Unloading idle TTS model...")
|
||||
self._service.unload_model()
|
||||
self._last_used = time.time()
|
||||
|
||||
async def get_service(self) -> TTSService:
|
||||
if not self._service or self._service.model is None:
|
||||
await self.load()
|
||||
self._last_used = time.time()
|
||||
return self._service # type: ignore[return-value]
|
||||
|
||||
async def _inc(self) -> None:
|
||||
async with self._counter_lock:
|
||||
self._active += 1
|
||||
|
||||
async def _dec(self) -> None:
|
||||
async with self._counter_lock:
|
||||
self._active = max(0, self._active - 1)
|
||||
self._last_used = time.time()
|
||||
|
||||
def last_used(self) -> float:
|
||||
return self._last_used
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
return bool(self._service and self._service.model is not None)
|
||||
|
||||
def active(self) -> int:
|
||||
return self._active
|
||||
|
||||
def using(self):
|
||||
manager = self
|
||||
|
||||
class _Ctx:
|
||||
async def __aenter__(self):
|
||||
await manager._inc()
|
||||
return manager
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await manager._dec()
|
||||
|
||||
return _Ctx()
|
Loading…
Reference in New Issue