backend: implement idle TTS model eviction\n\n- Add MODEL_EVICTION_ENABLED, MODEL_IDLE_TIMEOUT_SECONDS, MODEL_IDLE_CHECK_INTERVAL_SECONDS in app/config.py\n- Add ModelManager service to manage TTSService load/unload with usage tracking\n- Add background idle reaper in app/main.py (startup/shutdown hooks)\n- Refactor dialog router to use ModelManager dependency instead of per-request load/unload

This commit is contained in:
Steve White 2025-08-12 16:33:54 -05:00
parent 41f95cdee3
commit cbc164c7a3
4 changed files with 164 additions and 24 deletions

View File

@ -67,6 +67,14 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
# Device configuration
DEVICE = os.getenv("DEVICE", "auto")
# Model idle eviction configuration
# Enable/disable idle-based model eviction
MODEL_EVICTION_ENABLED = os.getenv("MODEL_EVICTION_ENABLED", "true").lower() == "true"
# Unload model after this many seconds of inactivity (0 disables eviction)
MODEL_IDLE_TIMEOUT_SECONDS = int(os.getenv("MODEL_IDLE_TIMEOUT_SECONDS", "900"))
# How often the reaper checks for idleness
MODEL_IDLE_CHECK_INTERVAL_SECONDS = int(os.getenv("MODEL_IDLE_CHECK_INTERVAL_SECONDS", "60"))
# Ensure directories exist
SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
TTS_TEMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

View File

@ -2,6 +2,10 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pathlib import Path
import asyncio
import contextlib
import logging
import time
from app.routers import speakers, dialog # Import the routers
from app import config
@ -38,3 +42,39 @@ config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
# Further endpoints for speakers, dialog generation, etc., will be added here.
# --- Background task: idle model reaper ---
logger = logging.getLogger("app.model_reaper")
@app.on_event("startup")
async def _start_model_reaper():
from app.services.model_manager import ModelManager
async def reaper():
while True:
try:
await asyncio.sleep(config.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
if not getattr(config, "MODEL_EVICTION_ENABLED", True):
continue
timeout = getattr(config, "MODEL_IDLE_TIMEOUT_SECONDS", 0)
if timeout <= 0:
continue
m = ModelManager.instance()
if m.is_loaded() and m.active() == 0 and (time.time() - m.last_used()) >= timeout:
logger.info("Idle timeout reached (%.0fs). Unloading model...", timeout)
await m.unload()
except asyncio.CancelledError:
break
except Exception:
logger.exception("Model reaper encountered an error")
app.state._model_reaper_task = asyncio.create_task(reaper())
@app.on_event("shutdown")
async def _stop_model_reaper():
task = getattr(app.state, "_model_reaper_task", None)
if task:
task.cancel()
with contextlib.suppress(Exception):
await task

View File

@ -9,6 +9,8 @@ from app.services.speaker_service import SpeakerManagementService
from app.services.dialog_processor_service import DialogProcessorService
from app.services.audio_manipulation_service import AudioManipulationService
from app import config
from typing import AsyncIterator
from app.services.model_manager import ModelManager
router = APIRouter()
@ -16,9 +18,12 @@ router = APIRouter()
# These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init.
# For now, direct instantiation or simple Depends is fine.
def get_tts_service():
# Consider making device configurable
return TTSService(device="mps")
async def get_tts_service() -> AsyncIterator[TTSService]:
"""Dependency that holds a usage token for the duration of the request."""
manager = ModelManager.instance()
async with manager.using():
service = await manager.get_service()
yield service
def get_speaker_management_service():
return SpeakerManagementService()
@ -32,7 +37,7 @@ def get_dialog_processor_service(
def get_audio_manipulation_service():
return AudioManipulationService()
# --- Helper function to manage TTS model loading/unloading ---
# --- Helper imports ---
from app.models.dialog_models import SpeechItem, SilenceItem
from app.services.tts_service import TTSService
@ -128,19 +133,7 @@ async def generate_line(
detail=error_detail
)
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
"""Loads TTS model, executes task, then unloads model."""
try:
print("API: Loading TTS model...")
tts_service.load_model()
return await task_function(*args, **kwargs)
except Exception as e:
# Log or handle specific exceptions if needed before re-raising
print(f"API: Error during TTS model lifecycle or task execution: {e}")
raise
finally:
print("API: Unloading TTS model...")
tts_service.unload_model()
# Removed per-request load/unload in favor of ModelManager idle eviction.
async def process_dialog_flow(
request: DialogRequest,
@ -274,12 +267,10 @@ async def generate_dialog_endpoint(
- Concatenates all audio segments into a single file.
- Creates a ZIP archive of all individual segments and the concatenated file.
"""
# Wrap the core processing logic with model loading/unloading
return await manage_tts_model_lifecycle(
tts_service,
process_dialog_flow,
request=request,
dialog_processor=dialog_processor,
# Execute core processing; ModelManager dependency keeps the model marked "in use".
return await process_dialog_flow(
request=request,
dialog_processor=dialog_processor,
audio_manipulator=audio_manipulator,
background_tasks=background_tasks
background_tasks=background_tasks,
)

View File

@ -0,0 +1,101 @@
import asyncio
import time
import logging
from typing import Optional
from app import config
from app.services.tts_service import TTSService
logger = logging.getLogger(__name__)
class ModelManager:
_instance: Optional["ModelManager"] = None
def __init__(self):
self._service: Optional[TTSService] = None
self._last_used: float = time.time()
self._active: int = 0
self._lock = asyncio.Lock()
self._counter_lock = asyncio.Lock()
@classmethod
def instance(cls) -> "ModelManager":
if not cls._instance:
cls._instance = cls()
return cls._instance
async def _ensure_service(self) -> None:
if self._service is None:
# Use configured device, default is handled by TTSService itself
device = getattr(config, "DEVICE", "auto")
# TTSService presently expects explicit device like "mps"/"cpu"/"cuda"; map "auto" to "mps" on Mac otherwise cpu
if device == "auto":
try:
import torch
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
except Exception:
device = "cpu"
self._service = TTSService(device=device)
async def load(self) -> None:
async with self._lock:
await self._ensure_service()
if self._service and self._service.model is None:
logger.info("Loading TTS model (device=%s)...", self._service.device)
self._service.load_model()
self._last_used = time.time()
async def unload(self) -> None:
async with self._lock:
if not self._service:
return
if self._active > 0:
logger.debug("Skip unload: %d active operations", self._active)
return
if self._service.model is not None:
logger.info("Unloading idle TTS model...")
self._service.unload_model()
self._last_used = time.time()
async def get_service(self) -> TTSService:
if not self._service or self._service.model is None:
await self.load()
self._last_used = time.time()
return self._service # type: ignore[return-value]
async def _inc(self) -> None:
async with self._counter_lock:
self._active += 1
async def _dec(self) -> None:
async with self._counter_lock:
self._active = max(0, self._active - 1)
self._last_used = time.time()
def last_used(self) -> float:
return self._last_used
def is_loaded(self) -> bool:
return bool(self._service and self._service.model is not None)
def active(self) -> int:
return self._active
def using(self):
manager = self
class _Ctx:
async def __aenter__(self):
await manager._inc()
return manager
async def __aexit__(self, exc_type, exc, tb):
await manager._dec()
return _Ctx()