"""Chatterbox TTS API: FastAPI app setup (CORS, routers, generated-audio static files, idle-model reaper)."""
from fastapi import FastAPI
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pathlib import Path
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
import time
|
|
from app.routers import speakers, dialog # Import the routers
|
|
from app import config
|
|
|
|
# FastAPI application instance, carrying the API's version, title and
# description metadata.
app = FastAPI(
    version="0.1.0",
    title="Chatterbox TTS API",
    description="API for generating TTS dialogs using Chatterbox TTS.",
)

# CORS middleware. Allowed origins come from config.CORS_ORIGINS (keep that
# list restricted to the real front-end origins in production). All methods
# and request headers are allowed, all response headers are exposed, and
# credentials are not allowed (allow_credentials=False).
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=config.CORS_ORIGINS,
    allow_headers=["*"],
    allow_methods=["*"],
    expose_headers=["*"],
)

# Mount the feature routers under the /api prefix.
app.include_router(
    speakers.router,
    prefix="/api/speakers",
    tags=["Speakers"],
)
app.include_router(
    dialog.router,
    prefix="/api/dialog",
    tags=["Dialog Generation"],
)


@app.get("/")
async def read_root():
    """Answer GET / with a fixed welcome message."""
    greeting = "Welcome to the Chatterbox TTS API!"
    return {"message": greeting}


# Create the generated-dialog directory (and any missing parents) before
# mounting it, then serve its files as static content under /generated_audio.
config.DIALOG_GENERATED_DIR.mkdir(exist_ok=True, parents=True)
app.mount(
    "/generated_audio",
    StaticFiles(directory=config.DIALOG_GENERATED_DIR),
    name="generated_audio",
)

# Speaker and dialog endpoints are provided by the routers included above;
# add any further top-level endpoints here.

# --- Background task: idle model reaper ---
# Reaper messages are logged under the fixed name "app.model_reaper"
# rather than this module's __name__.
logger = logging.getLogger(name="app.model_reaper")


@app.on_event("startup")
async def _start_model_reaper():
    """Launch the background task that unloads the TTS model when it is idle.

    On each pass the task sleeps for ``config.MODEL_IDLE_CHECK_INTERVAL_SECONDS``
    seconds (falling back to 60s when that value is missing, non-numeric or not
    positive). It then unloads the ``ModelManager`` singleton's model when all of
    the following hold:

    * ``config.MODEL_EVICTION_ENABLED`` is truthy (default ``True``),
    * ``config.MODEL_IDLE_TIMEOUT_SECONDS`` is positive (default ``0``, which
      disables eviction),
    * ``is_loaded()`` is true and ``active()`` is ``0``,
    * at least that many seconds have passed since ``last_used()``.

    Config values are re-read on every pass. Errors in a pass are logged and the
    loop continues; cancellation ends the loop.

    The task handle is stored on ``app.state._model_reaper_task`` so the shutdown
    hook can cancel it. Holding the reference also keeps the task from being
    garbage-collected while it runs (see ``asyncio.create_task`` docs).
    """
    # Imported here rather than at module top -- presumably to defer loading
    # the model service or to avoid an import cycle; confirm.
    from app.services.model_manager import ModelManager

    # Fallback poll interval in seconds. 60s is an arbitrary choice made here;
    # align it with config's own default if one exists.
    default_interval = 60.0

    def check_interval():
        """Return a usable positive sleep interval from config, or the fallback.

        Without this guard, a missing or non-numeric setting makes the lookup
        (or ``asyncio.sleep``) raise before the loop's only ``await``. The
        ``except Exception`` below would then loop again without ever yielding,
        which freezes the event loop. A non-positive value would make the loop
        hot-spin on ``sleep(0)``.
        """
        raw = getattr(config, "MODEL_IDLE_CHECK_INTERVAL_SECONDS", default_interval)
        try:
            interval = float(raw)
        except (TypeError, ValueError):
            return default_interval
        # `interval > 0` is False for NaN as well, so NaN also falls back.
        return interval if interval > 0 else default_interval

    async def reaper():
        while True:
            try:
                await asyncio.sleep(check_interval())
                if not getattr(config, "MODEL_EVICTION_ENABLED", True):
                    continue
                timeout = getattr(config, "MODEL_IDLE_TIMEOUT_SECONDS", 0)
                if timeout <= 0:
                    continue
                m = ModelManager.instance()
                if m.is_loaded() and m.active() == 0 and (time.time() - m.last_used()) >= timeout:
                    logger.info("Idle timeout reached (%.0fs). Unloading model...", timeout)
                    await m.unload()
            except asyncio.CancelledError:
                # Cancelled by the shutdown hook: stop the loop.
                break
            except Exception:
                # Best-effort housekeeping: log and try again on the next pass.
                logger.exception("Model reaper encountered an error")

    app.state._model_reaper_task = asyncio.create_task(reaper())


@app.on_event("shutdown")
async def _stop_model_reaper():
    """Cancel the idle-model reaper task created at startup and wait for it to finish.

    Does nothing if no task was stored on ``app.state._model_reaper_task``.
    Any error from the task is suppressed, so shutdown continues best-effort.
    """
    task = getattr(app.state, "_model_reaper_task", None)
    if task is None:
        return
    task.cancel()
    # CancelledError derives from BaseException on Python 3.8+, so suppressing
    # only Exception would let it escape this hook. Awaiting the task raises
    # CancelledError whenever the task ends cancelled -- e.g. when it is
    # cancelled before its first step, so the reaper's own
    # `except asyncio.CancelledError` never gets a chance to run.
    with contextlib.suppress(asyncio.CancelledError, Exception):
        await task