diff --git a/backend/app/config.py b/backend/app/config.py index be3b33e..60e2d67 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -67,6 +67,14 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port: # Device configuration DEVICE = os.getenv("DEVICE", "auto") +# Model idle eviction configuration +# Enable/disable idle-based model eviction +MODEL_EVICTION_ENABLED = os.getenv("MODEL_EVICTION_ENABLED", "true").lower() == "true" +# Unload model after this many seconds of inactivity (0 disables eviction) +MODEL_IDLE_TIMEOUT_SECONDS = int(os.getenv("MODEL_IDLE_TIMEOUT_SECONDS", "900")) +# How often the reaper checks for idleness +MODEL_IDLE_CHECK_INTERVAL_SECONDS = int(os.getenv("MODEL_IDLE_CHECK_INTERVAL_SECONDS", "60")) + # Ensure directories exist SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True) TTS_TEMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) diff --git a/backend/app/main.py b/backend/app/main.py index 7b57297..90d8a79 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,6 +2,10 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from pathlib import Path +import asyncio +import contextlib +import logging +import time from app.routers import speakers, dialog # Import the routers from app import config @@ -38,3 +42,39 @@ config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True) app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio") # Further endpoints for speakers, dialog generation, etc., will be added here. + +# --- Background task: idle model reaper --- +logger = logging.getLogger("app.model_reaper") + +@app.on_event("startup") +async def _start_model_reaper(): + from app.services.model_manager import ModelManager + + async def reaper(): + while True: + try: + await asyncio.sleep(config.MODEL_IDLE_CHECK_INTERVAL_SECONDS) + if not getattr(config, "MODEL_EVICTION_ENABLED", True): + continue + timeout = getattr(config, "MODEL_IDLE_TIMEOUT_SECONDS", 0) + if timeout <= 0: + continue + m = ModelManager.instance() + if m.is_loaded() and m.active() == 0 and (time.time() - m.last_used()) >= timeout: + logger.info("Idle timeout reached (%.0fs). Unloading model...", timeout) + await m.unload() + except asyncio.CancelledError: + break + except Exception: + logger.exception("Model reaper encountered an error") + + app.state._model_reaper_task = asyncio.create_task(reaper()) + + +@app.on_event("shutdown") +async def _stop_model_reaper(): + task = getattr(app.state, "_model_reaper_task", None) + if task: + task.cancel() + with contextlib.suppress(Exception): + await task diff --git a/backend/app/routers/dialog.py b/backend/app/routers/dialog.py index adb0508..73dd721 100644 --- a/backend/app/routers/dialog.py +++ b/backend/app/routers/dialog.py @@ -9,6 +9,8 @@ from app.services.speaker_service import SpeakerManagementService from app.services.dialog_processor_service import DialogProcessorService from app.services.audio_manipulation_service import AudioManipulationService from app import config +from typing import AsyncIterator +from app.services.model_manager import ModelManager router = APIRouter() @@ -16,9 +18,12 @@ router = APIRouter() # These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init. # For now, direct instantiation or simple Depends is fine. -def get_tts_service(): - # Consider making device configurable - return TTSService(device="mps") +async def get_tts_service() -> AsyncIterator[TTSService]: + """Dependency that holds a usage token for the duration of the request.""" + manager = ModelManager.instance() + async with manager.using(): + service = await manager.get_service() + yield service def get_speaker_management_service(): return SpeakerManagementService() @@ -32,7 +37,7 @@ def get_dialog_processor_service( def get_audio_manipulation_service(): return AudioManipulationService() -# --- Helper function to manage TTS model loading/unloading --- +# --- Helper imports --- from app.models.dialog_models import SpeechItem, SilenceItem from app.services.tts_service import TTSService @@ -128,19 +133,7 @@ async def generate_line( detail=error_detail ) -async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs): - """Loads TTS model, executes task, then unloads model.""" - try: - print("API: Loading TTS model...") - tts_service.load_model() - return await task_function(*args, **kwargs) - except Exception as e: - # Log or handle specific exceptions if needed before re-raising - print(f"API: Error during TTS model lifecycle or task execution: {e}") - raise - finally: - print("API: Unloading TTS model...") - tts_service.unload_model() +# Removed per-request load/unload in favor of ModelManager idle eviction. async def process_dialog_flow( request: DialogRequest, @@ -274,12 +267,10 @@ async def generate_dialog_endpoint( - Concatenates all audio segments into a single file. - Creates a ZIP archive of all individual segments and the concatenated file. """ - # Wrap the core processing logic with model loading/unloading - return await manage_tts_model_lifecycle( - tts_service, - process_dialog_flow, - request=request, - dialog_processor=dialog_processor, + # Execute core processing; ModelManager dependency keeps the model marked "in use". + return await process_dialog_flow( + request=request, + dialog_processor=dialog_processor, audio_manipulator=audio_manipulator, - background_tasks=background_tasks + background_tasks=background_tasks, ) diff --git a/backend/app/services/model_manager.py b/backend/app/services/model_manager.py new file mode 100644 index 0000000..c071271 --- /dev/null +++ b/backend/app/services/model_manager.py @@ -0,0 +1,101 @@ +import asyncio +import time +import logging +from typing import Optional + +from app import config +from app.services.tts_service import TTSService + +logger = logging.getLogger(__name__) + + +class ModelManager: + _instance: Optional["ModelManager"] = None + + def __init__(self): + self._service: Optional[TTSService] = None + self._last_used: float = time.time() + self._active: int = 0 + self._lock = asyncio.Lock() + self._counter_lock = asyncio.Lock() + + @classmethod + def instance(cls) -> "ModelManager": + if not cls._instance: + cls._instance = cls() + return cls._instance + + async def _ensure_service(self) -> None: + if self._service is None: + # Use configured device, default is handled by TTSService itself + device = getattr(config, "DEVICE", "auto") + # TTSService presently expects explicit device like "mps"/"cpu"/"cuda"; map "auto" to "mps" on Mac otherwise cpu + if device == "auto": + try: + import torch + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + except Exception: + device = "cpu" + self._service = TTSService(device=device) + + async def load(self) -> None: + async with self._lock: + await self._ensure_service() + if self._service and self._service.model is None: + logger.info("Loading TTS model (device=%s)...", self._service.device) + self._service.load_model() + self._last_used = time.time() + + async def unload(self) -> None: + async with self._lock: + if not self._service: + return + if self._active > 0: + logger.debug("Skip unload: %d active operations", self._active) + return + if self._service.model is not None: + logger.info("Unloading idle TTS model...") + self._service.unload_model() + self._last_used = time.time() + + async def get_service(self) -> TTSService: + if not self._service or self._service.model is None: + await self.load() + self._last_used = time.time() + return self._service # type: ignore[return-value] + + async def _inc(self) -> None: + async with self._counter_lock: + self._active += 1 + + async def _dec(self) -> None: + async with self._counter_lock: + self._active = max(0, self._active - 1) + self._last_used = time.time() + + def last_used(self) -> float: + return self._last_used + + def is_loaded(self) -> bool: + return bool(self._service and self._service.model is not None) + + def active(self) -> int: + return self._active + + def using(self): + manager = self + + class _Ctx: + async def __aenter__(self): + await manager._inc() + return manager + + async def __aexit__(self, exc_type, exc, tb): + await manager._dec() + + return _Ctx()