backend: implement idle TTS model eviction

- Add MODEL_EVICTION_ENABLED, MODEL_IDLE_TIMEOUT_SECONDS, MODEL_IDLE_CHECK_INTERVAL_SECONDS in app/config.py
- Add ModelManager service to manage TTSService load/unload with usage tracking
- Add background idle reaper in app/main.py (startup/shutdown hooks)
- Refactor dialog router to use ModelManager dependency instead of per-request load/unload
This commit is contained in:
parent
41f95cdee3
commit
cbc164c7a3
|
@ -67,6 +67,14 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
|
||||||
# Device configuration
# "auto" defers backend selection to the consumer (resolved to mps/cuda/cpu there).
DEVICE = os.getenv("DEVICE", "auto")

# Model idle eviction configuration

# Enable/disable idle-based model eviction — any value other than "true"
# (case-insensitive) disables the reaper's unloading behavior.
MODEL_EVICTION_ENABLED = os.getenv("MODEL_EVICTION_ENABLED", "true").lower() == "true"

# Unload model after this many seconds of inactivity (0 disables eviction)
MODEL_IDLE_TIMEOUT_SECONDS = int(os.getenv("MODEL_IDLE_TIMEOUT_SECONDS", "900"))

# How often the reaper checks for idleness
MODEL_IDLE_CHECK_INTERVAL_SECONDS = int(os.getenv("MODEL_IDLE_CHECK_INTERVAL_SECONDS", "60"))

# Ensure directories exist
SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
TTS_TEMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
@ -2,6 +2,10 @@ from fastapi import FastAPI
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
from app.routers import speakers, dialog # Import the routers
|
from app.routers import speakers, dialog # Import the routers
|
||||||
from app import config
|
from app import config
|
||||||
|
|
||||||
|
@ -38,3 +42,39 @@ config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
|
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
|
||||||
|
|
||||||
# Further endpoints for speakers, dialog generation, etc., will be added here.
|
# Further endpoints for speakers, dialog generation, etc., will be added here.
|
||||||
|
|
||||||
|
# --- Background task: idle model reaper ---
logger = logging.getLogger("app.model_reaper")


@app.on_event("startup")
async def _start_model_reaper():
    """Spawn the background task that evicts the TTS model after idle timeout."""
    # Imported lazily so importing app.main does not pull in model machinery.
    from app.services.model_manager import ModelManager

    async def reaper():
        # Wake up periodically; unload the model once it has sat idle with
        # no in-flight requests for at least the configured timeout.
        while True:
            try:
                await asyncio.sleep(config.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
                if not getattr(config, "MODEL_EVICTION_ENABLED", True):
                    continue
                timeout = getattr(config, "MODEL_IDLE_TIMEOUT_SECONDS", 0)
                if timeout <= 0:
                    # A non-positive timeout disables eviction entirely.
                    continue
                manager = ModelManager.instance()
                idle_for = time.time() - manager.last_used()
                if manager.is_loaded() and manager.active() == 0 and idle_for >= timeout:
                    logger.info("Idle timeout reached (%.0fs). Unloading model...", timeout)
                    await manager.unload()
            except asyncio.CancelledError:
                break
            except Exception:
                # Keep the reaper alive across unexpected errors.
                logger.exception("Model reaper encountered an error")

    app.state._model_reaper_task = asyncio.create_task(reaper())
|
@app.on_event("shutdown")
async def _stop_model_reaper():
    """Cancel the idle-reaper task and wait for it to finish.

    Bug fix: ``asyncio.CancelledError`` derives from ``BaseException`` (not
    ``Exception``) since Python 3.8, so ``contextlib.suppress(Exception)``
    alone would let the cancellation escape and fail the shutdown hook when
    the task is cancelled outside the reaper's own try block.
    """
    task = getattr(app.state, "_model_reaper_task", None)
    if task:
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError, Exception):
            await task
|
@ -9,6 +9,8 @@ from app.services.speaker_service import SpeakerManagementService
|
||||||
from app.services.dialog_processor_service import DialogProcessorService
|
from app.services.dialog_processor_service import DialogProcessorService
|
||||||
from app.services.audio_manipulation_service import AudioManipulationService
|
from app.services.audio_manipulation_service import AudioManipulationService
|
||||||
from app import config
|
from app import config
|
||||||
|
from typing import AsyncIterator
|
||||||
|
from app.services.model_manager import ModelManager
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -16,9 +18,12 @@ router = APIRouter()
|
||||||
# These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init.
|
# These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init.
|
||||||
# For now, direct instantiation or simple Depends is fine.
|
# For now, direct instantiation or simple Depends is fine.
|
||||||
|
|
||||||
async def get_tts_service() -> AsyncIterator[TTSService]:
    """Yield the shared TTSService while holding a usage token for the request.

    The token keeps the idle reaper from unloading the model mid-request.
    """
    manager = ModelManager.instance()
    async with manager.using():
        yield await manager.get_service()
def get_speaker_management_service():
    """FastAPI dependency: a fresh SpeakerManagementService per request."""
    service = SpeakerManagementService()
    return service
|
@ -32,7 +37,7 @@ def get_dialog_processor_service(
|
||||||
def get_audio_manipulation_service():
    """FastAPI dependency: a fresh AudioManipulationService per request."""
    service = AudioManipulationService()
    return service
|
|
||||||
# --- Helper function to manage TTS model loading/unloading ---
|
# --- Helper imports ---
|
||||||
|
|
||||||
from app.models.dialog_models import SpeechItem, SilenceItem
|
from app.models.dialog_models import SpeechItem, SilenceItem
|
||||||
from app.services.tts_service import TTSService
|
from app.services.tts_service import TTSService
|
||||||
|
@ -128,19 +133,7 @@ async def generate_line(
|
||||||
detail=error_detail
|
detail=error_detail
|
||||||
)
|
)
|
||||||
|
|
||||||
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
# Removed per-request load/unload in favor of ModelManager idle eviction.
|
||||||
"""Loads TTS model, executes task, then unloads model."""
|
|
||||||
try:
|
|
||||||
print("API: Loading TTS model...")
|
|
||||||
tts_service.load_model()
|
|
||||||
return await task_function(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
# Log or handle specific exceptions if needed before re-raising
|
|
||||||
print(f"API: Error during TTS model lifecycle or task execution: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
print("API: Unloading TTS model...")
|
|
||||||
tts_service.unload_model()
|
|
||||||
|
|
||||||
async def process_dialog_flow(
|
async def process_dialog_flow(
|
||||||
request: DialogRequest,
|
request: DialogRequest,
|
||||||
|
@ -274,12 +267,10 @@ async def generate_dialog_endpoint(
|
||||||
- Concatenates all audio segments into a single file.
|
- Concatenates all audio segments into a single file.
|
||||||
- Creates a ZIP archive of all individual segments and the concatenated file.
|
- Creates a ZIP archive of all individual segments and the concatenated file.
|
||||||
"""
|
"""
|
||||||
# Wrap the core processing logic with model loading/unloading
|
# Execute core processing; ModelManager dependency keeps the model marked "in use".
|
||||||
return await manage_tts_model_lifecycle(
|
return await process_dialog_flow(
|
||||||
tts_service,
|
|
||||||
process_dialog_flow,
|
|
||||||
request=request,
|
request=request,
|
||||||
dialog_processor=dialog_processor,
|
dialog_processor=dialog_processor,
|
||||||
audio_manipulator=audio_manipulator,
|
audio_manipulator=audio_manipulator,
|
||||||
background_tasks=background_tasks
|
background_tasks=background_tasks,
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app import config
|
||||||
|
from app.services.tts_service import TTSService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
    """Singleton owning the shared TTSService, tracking usage for idle eviction.

    Requests take a usage token via ``using()``; the background reaper calls
    ``unload()`` only when no tokens are held and the model has sat idle.
    Designed for a single asyncio event loop (not thread-safe across loops).
    """

    _instance: Optional["ModelManager"] = None

    def __init__(self):
        # Lazily-created TTS service wrapper; its model may still be unloaded.
        self._service: Optional["TTSService"] = None
        # Timestamp of the last load/use/release, consulted by the reaper.
        self._last_used: float = time.time()
        # Number of in-flight requests currently holding a usage token.
        self._active: int = 0
        # Serializes load/unload; a separate lock keeps counter updates cheap.
        self._lock = asyncio.Lock()
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls) -> "ModelManager":
        """Return the process-wide ModelManager, creating it on first use."""
        # Explicit `is None` instead of truthiness: the guard should depend
        # only on whether the singleton exists, not on any __bool__ semantics.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def _ensure_service(self) -> None:
        """Create the TTSService wrapper (without loading weights) if needed."""
        if self._service is None:
            # Use configured device; default handling stays with TTSService.
            device = getattr(config, "DEVICE", "auto")
            # TTSService expects an explicit device ("mps"/"cuda"/"cpu"),
            # so resolve "auto" by probing the available torch backends.
            if device == "auto":
                try:
                    import torch

                    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                        device = "mps"
                    elif torch.cuda.is_available():
                        device = "cuda"
                    else:
                        device = "cpu"
                except Exception:
                    # torch unavailable or probing failed — fall back to CPU.
                    device = "cpu"
            self._service = TTSService(device=device)

    async def load(self) -> None:
        """Load the model if it is not already loaded (idempotent, lock-guarded)."""
        async with self._lock:
            await self._ensure_service()
            if self._service and self._service.model is None:
                logger.info("Loading TTS model (device=%s)...", self._service.device)
                self._service.load_model()
            self._last_used = time.time()

    async def unload(self) -> None:
        """Unload the model unless a request is actively using it."""
        async with self._lock:
            if not self._service:
                return
            if self._active > 0:
                logger.debug("Skip unload: %d active operations", self._active)
                return
            if self._service.model is not None:
                logger.info("Unloading idle TTS model...")
                self._service.unload_model()
            self._last_used = time.time()

    async def get_service(self) -> "TTSService":
        """Return the service with its model loaded, refreshing the idle clock.

        Fix: the loaded-check previously happened outside ``_lock`` (a
        check-then-act race); ``load()`` already performs the guarded check
        and is idempotent, so it is awaited unconditionally instead.
        """
        await self.load()
        self._last_used = time.time()
        return self._service  # type: ignore[return-value]

    async def _inc(self) -> None:
        # Take a usage token; the reaper will not evict while any are held.
        async with self._counter_lock:
            self._active += 1

    async def _dec(self) -> None:
        # Release a usage token and restart the idle clock.
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self) -> float:
        """Epoch seconds of the most recent load/use/release."""
        return self._last_used

    def is_loaded(self) -> bool:
        """True when the underlying model weights are resident."""
        return bool(self._service and self._service.model is not None)

    def active(self) -> int:
        """Number of requests currently holding a usage token."""
        return self._active

    def using(self):
        """Return an async context manager marking the model as in-use."""
        manager = self

        class _Token:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()

        return _Token()
Loading…
Reference in New Issue