backend: implement idle TTS model eviction

- Add MODEL_EVICTION_ENABLED, MODEL_IDLE_TIMEOUT_SECONDS, MODEL_IDLE_CHECK_INTERVAL_SECONDS in app/config.py
- Add ModelManager service to manage TTSService load/unload with usage tracking
- Add background idle reaper in app/main.py (startup/shutdown hooks)
- Refactor dialog router to use ModelManager dependency instead of per-request load/unload

This commit is contained in:
Steve White 2025-08-12 16:33:54 -05:00
parent 41f95cdee3
commit cbc164c7a3
4 changed files with 164 additions and 24 deletions

View File

@ -67,6 +67,14 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
# Device configuration # Device configuration
DEVICE = os.getenv("DEVICE", "auto") DEVICE = os.getenv("DEVICE", "auto")
# Model idle eviction configuration
# Enable/disable idle-based model eviction (env MODEL_EVICTION_ENABLED; any
# casing of "true" enables it, everything else disables it).
MODEL_EVICTION_ENABLED = os.getenv("MODEL_EVICTION_ENABLED", "true").lower() == "true"
# Unload model after this many seconds of inactivity (0 disables eviction)
# Default: 900 s (15 minutes). NOTE(review): int() raises ValueError on a
# non-numeric env value — confirm that failing fast at import time is intended.
MODEL_IDLE_TIMEOUT_SECONDS = int(os.getenv("MODEL_IDLE_TIMEOUT_SECONDS", "900"))
# How often the reaper checks for idleness
# Default: every 60 seconds; this bounds how late after the timeout an
# actual unload can happen.
MODEL_IDLE_CHECK_INTERVAL_SECONDS = int(os.getenv("MODEL_IDLE_CHECK_INTERVAL_SECONDS", "60"))
# Ensure directories exist # Ensure directories exist
SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True) SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
TTS_TEMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) TTS_TEMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

View File

@ -2,6 +2,10 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pathlib import Path from pathlib import Path
import asyncio
import contextlib
import logging
import time
from app.routers import speakers, dialog # Import the routers from app.routers import speakers, dialog # Import the routers
from app import config from app import config
@ -38,3 +42,39 @@ config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio") app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
# Further endpoints for speakers, dialog generation, etc., will be added here. # Further endpoints for speakers, dialog generation, etc., will be added here.
# --- Background task: idle model reaper ---
logger = logging.getLogger("app.model_reaper")
@app.on_event("startup")
async def _start_model_reaper():
    """Spawn the background task that unloads the TTS model after idling."""
    # Imported here to avoid a module-level import cycle with the router.
    from app.services.model_manager import ModelManager

    async def _watchdog():
        while True:
            try:
                # Sleep first so the check interval also delays the initial check.
                await asyncio.sleep(config.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
                eviction_on = getattr(config, "MODEL_EVICTION_ENABLED", True)
                idle_limit = getattr(config, "MODEL_IDLE_TIMEOUT_SECONDS", 0)
                if not eviction_on or idle_limit <= 0:
                    continue
                manager = ModelManager.instance()
                idle_for = time.time() - manager.last_used()
                # Evict only when loaded, unused, and idle past the limit.
                should_evict = (
                    manager.is_loaded()
                    and manager.active() == 0
                    and idle_for >= idle_limit
                )
                if should_evict:
                    logger.info("Idle timeout reached (%.0fs). Unloading model...", idle_limit)
                    await manager.unload()
            except asyncio.CancelledError:
                # Shutdown hook cancelled us; exit the loop cleanly.
                break
            except Exception:
                # Keep the watchdog alive on unexpected errors.
                logger.exception("Model reaper encountered an error")

    app.state._model_reaper_task = asyncio.create_task(_watchdog())
@app.on_event("shutdown")
async def _stop_model_reaper():
    """Cancel the idle-reaper task and wait for it to wind down."""
    reaper_task = getattr(app.state, "_model_reaper_task", None)
    if reaper_task is None:
        return
    reaper_task.cancel()
    # Awaiting a cancelled task raises CancelledError; swallow anything
    # the task surfaces during teardown.
    with contextlib.suppress(Exception):
        await reaper_task

View File

@ -9,6 +9,8 @@ from app.services.speaker_service import SpeakerManagementService
from app.services.dialog_processor_service import DialogProcessorService from app.services.dialog_processor_service import DialogProcessorService
from app.services.audio_manipulation_service import AudioManipulationService from app.services.audio_manipulation_service import AudioManipulationService
from app import config from app import config
from typing import AsyncIterator
from app.services.model_manager import ModelManager
router = APIRouter() router = APIRouter()
@ -16,9 +18,12 @@ router = APIRouter()
# These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init. # These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init.
# For now, direct instantiation or simple Depends is fine. # For now, direct instantiation or simple Depends is fine.
async def get_tts_service() -> AsyncIterator[TTSService]:
    """Yield the shared TTSService while holding a usage token.

    The token keeps the ModelManager's active count non-zero for the whole
    request, so the idle reaper will not unload the model mid-request.
    """
    manager = ModelManager.instance()
    async with manager.using():
        yield await manager.get_service()
def get_speaker_management_service(): def get_speaker_management_service():
return SpeakerManagementService() return SpeakerManagementService()
@ -32,7 +37,7 @@ def get_dialog_processor_service(
def get_audio_manipulation_service(): def get_audio_manipulation_service():
return AudioManipulationService() return AudioManipulationService()
# --- Helper function to manage TTS model loading/unloading --- # --- Helper imports ---
from app.models.dialog_models import SpeechItem, SilenceItem from app.models.dialog_models import SpeechItem, SilenceItem
from app.services.tts_service import TTSService from app.services.tts_service import TTSService
@ -128,19 +133,7 @@ async def generate_line(
detail=error_detail detail=error_detail
) )
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs): # Removed per-request load/unload in favor of ModelManager idle eviction.
"""Loads TTS model, executes task, then unloads model."""
try:
print("API: Loading TTS model...")
tts_service.load_model()
return await task_function(*args, **kwargs)
except Exception as e:
# Log or handle specific exceptions if needed before re-raising
print(f"API: Error during TTS model lifecycle or task execution: {e}")
raise
finally:
print("API: Unloading TTS model...")
tts_service.unload_model()
async def process_dialog_flow( async def process_dialog_flow(
request: DialogRequest, request: DialogRequest,
@ -274,12 +267,10 @@ async def generate_dialog_endpoint(
- Concatenates all audio segments into a single file. - Concatenates all audio segments into a single file.
- Creates a ZIP archive of all individual segments and the concatenated file. - Creates a ZIP archive of all individual segments and the concatenated file.
""" """
# Wrap the core processing logic with model loading/unloading # Execute core processing; ModelManager dependency keeps the model marked "in use".
return await manage_tts_model_lifecycle( return await process_dialog_flow(
tts_service,
process_dialog_flow,
request=request, request=request,
dialog_processor=dialog_processor, dialog_processor=dialog_processor,
audio_manipulator=audio_manipulator, audio_manipulator=audio_manipulator,
background_tasks=background_tasks background_tasks=background_tasks,
) )

View File

@ -0,0 +1,101 @@
import asyncio
import time
import logging
from typing import Optional
from app import config
from app.services.tts_service import TTSService
logger = logging.getLogger(__name__)
class ModelManager:
    """Process-wide singleton that owns the TTSService instance.

    Tracks how many requests are currently using the model (``_active``)
    and when it was last touched (``_last_used``) so a background reaper
    can unload the model after a period of inactivity.
    """

    _instance: Optional["ModelManager"] = None

    def __init__(self) -> None:
        # Lazily-created TTS service; None until first use.
        self._service: Optional["TTSService"] = None
        # Wall-clock timestamp of the most recent load/use/release.
        self._last_used: float = time.time()
        # Number of in-flight requests holding a usage token.
        self._active: int = 0
        # Serializes load/unload so they never interleave.
        self._lock = asyncio.Lock()
        # Protects the active counter independently of load/unload.
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls) -> "ModelManager":
        """Return the shared singleton, creating it on first call.

        NOTE(review): not guarded against concurrent first calls from
        multiple OS threads; safe under a single asyncio event loop.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def _ensure_service(self) -> None:
        """Instantiate the TTSService once, resolving DEVICE='auto'.

        Must be called while holding ``self._lock``.
        """
        if self._service is not None:
            return
        # Use configured device; default handling lives in TTSService itself.
        device = getattr(config, "DEVICE", "auto")
        # TTSService expects an explicit device string ("mps"/"cuda"/"cpu");
        # map "auto" to the best available backend, falling back to CPU.
        if device == "auto":
            try:
                import torch

                if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                    device = "mps"
                elif torch.cuda.is_available():
                    device = "cuda"
                else:
                    device = "cpu"
            except Exception:
                device = "cpu"
        self._service = TTSService(device=device)

    async def load(self) -> None:
        """Ensure the service exists and its model is loaded (idempotent)."""
        async with self._lock:
            await self._ensure_service()
            if self._service and self._service.model is None:
                logger.info("Loading TTS model (device=%s)...", self._service.device)
                self._service.load_model()
            self._last_used = time.time()

    async def unload(self) -> None:
        """Unload the model if present and no request is using it."""
        async with self._lock:
            if not self._service:
                return
            # Never evict while a request holds a usage token.
            if self._active > 0:
                logger.debug("Skip unload: %d active operations", self._active)
                return
            if self._service.model is not None:
                logger.info("Unloading idle TTS model...")
                self._service.unload_model()
            self._last_used = time.time()

    async def get_service(self) -> "TTSService":
        """Return the service with a loaded model, loading it if needed.

        Delegates unconditionally to :meth:`load`, which is lock-protected
        and idempotent. The previous unlocked ``self._service.model is None``
        check here was a check-then-act race against :meth:`unload`.
        """
        await self.load()
        self._last_used = time.time()
        return self._service  # type: ignore[return-value]

    async def _inc(self) -> None:
        """Acquire one usage token."""
        async with self._counter_lock:
            self._active += 1

    async def _dec(self) -> None:
        """Release one usage token and refresh the idle clock."""
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self) -> float:
        """Timestamp (``time.time()``) of the most recent load/use/release."""
        return self._last_used

    def is_loaded(self) -> bool:
        """True when the underlying model is currently in memory."""
        return bool(self._service and self._service.model is not None)

    def active(self) -> int:
        """Number of requests currently holding a usage token."""
        return self._active

    def using(self):
        """Return an async context manager marking the model as in use.

        The token is released on exit even if the body raises; exceptions
        are never suppressed.
        """
        manager = self

        class _UsageToken:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()
                # Explicitly propagate any exception from the body.
                return False

        return _UsageToken()