# Unload Model on Idle: Implementation Plan

## Goals

- Automatically unload large TTS model(s) when idle to reduce RAM/VRAM usage.
- Lazy-load on demand without breaking API semantics.
- Configurable timeout and safety controls.

## Requirements

- Config-driven idle timeout and poll interval.
- Thread-/async-safe across concurrent requests.
- No unload while an inference is in progress.
- Clear logs and metrics for load/unload events.

## Configuration

File: `backend/app/config.py`

- Add:
  - `MODEL_IDLE_TIMEOUT_SECONDS: int = 900` (0 disables eviction)
  - `MODEL_IDLE_CHECK_INTERVAL_SECONDS: int = 60`
  - `MODEL_EVICTION_ENABLED: bool = True`
- Bind to env vars of the same names: `MODEL_IDLE_TIMEOUT_SECONDS`, `MODEL_IDLE_CHECK_INTERVAL_SECONDS`, `MODEL_EVICTION_ENABLED` (see the settings sketch below).
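
A minimal sketch of these fields, assuming `config.py` uses a pydantic `BaseSettings`-style class (an assumption; adapt to however `settings` is actually defined):

```python
# backend/app/config.py -- sketch; assumes pydantic-settings is in use.
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    MODEL_IDLE_TIMEOUT_SECONDS: int = 900        # 0 disables eviction
    MODEL_IDLE_CHECK_INTERVAL_SECONDS: int = 60  # reaper poll interval
    MODEL_EVICTION_ENABLED: bool = True

settings = Settings()  # fields populate from matching environment variables
```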

## Design

### ModelManager (Singleton)

File: `backend/app/services/model_manager.py` (new)

- Responsibilities:
  - Manage the lifecycle (load/unload) of the TTS model/pipeline.
  - Provide `get()`, which returns a ready model (lazy-loading if needed) and updates `last_used`.
  - Track the active request count and block eviction while it is > 0.
- Internals:
  - `self._model` (or components), `self._last_used: float`, `self._active: int`.
  - Locks: an `asyncio.Lock` for load/unload; an `asyncio.Lock` or `asyncio.Semaphore` for the counters.
  - Optional CUDA cleanup: `torch.cuda.empty_cache()` after unload.
- API (implemented in the Pseudocode section below):
  - `async def get(self) -> Model`: ensures the model is loaded; bumps `last_used`.
  - `async def load(self)`: idempotent; guarded by a lock.
  - `async def unload(self)`: only when `self._active == 0`; clears refs and caches.
  - `def touch(self)`: updates `last_used`.
  - Context helper `async def using(self)`: an async context manager that increments/decrements `active` safely.

### Idle Reaper Task

Registration: FastAPI startup (e.g., in `backend/app/main.py`)

- Run a background task loop every `MODEL_IDLE_CHECK_INTERVAL_SECONDS`:
  - If eviction is enabled, the timeout is > 0, the model is loaded, `active == 0`, and `now - last_used >= timeout`, call `unload()`.
- Handle cancellation on shutdown.

## API Integration

- Replace direct model access in endpoints with:

  ```python
  manager = ModelManager.instance()
  async with manager.using():
      model = await manager.get()
      # perform inference here
  ```

- Optionally call `manager.touch()` at request start for non-inference paths that still need the model resident. A fuller endpoint sketch follows below.
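
For illustration, a hedged endpoint sketch; the route path, `TTSRequest` model, and `run_inference()` helper are hypothetical stand-ins, not the project's real names:

```python
# Hypothetical endpoint sketch: TTSRequest, /tts, and run_inference()
# are illustrative placeholders.
from fastapi import APIRouter
from pydantic import BaseModel

from .services.model_manager import ModelManager

router = APIRouter()

class TTSRequest(BaseModel):
    text: str

@router.post("/tts")
async def synthesize(req: TTSRequest):
    manager = ModelManager.instance()
    async with manager.using():      # holds an active slot: no eviction mid-flight
        model = await manager.get()  # lazy-loads if evicted; bumps last_used
        audio = await run_inference(model, req.text)  # placeholder inference call
    return audio
```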

## Pseudocode

```python
# services/model_manager.py
import asyncio
import contextlib
import time
from typing import Optional

class ModelManager:
    _instance: Optional["ModelManager"] = None

    def __init__(self):
        self._model = None
        self._last_used = time.time()
        self._active = 0
        self._lock = asyncio.Lock()
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls):
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    async def load(self):
        async with self._lock:
            if self._model is not None:
                return
            # ... load the model/pipeline here; load_pipeline() is a
            # placeholder for the project's real loader ...
            self._model = await load_pipeline()
            self._last_used = time.time()

    async def unload(self):
        async with self._lock:
            if self._model is None:
                return
            if self._active > 0:
                return  # safety: do not unload while in use
            # ... free resources ...
            self._model = None
            try:
                import torch
                torch.cuda.empty_cache()
            except Exception:
                pass

    async def get(self):
        if self._model is None:
            await self.load()
        self._last_used = time.time()
        return self._model

    async def _inc(self):
        async with self._counter_lock:
            self._active += 1

    async def _dec(self):
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def touch(self):
        # Bump last_used without loading; for paths that keep the model resident.
        self._last_used = time.time()

    def last_used(self):
        return self._last_used

    def is_loaded(self):
        return self._model is not None

    def active(self):
        return self._active

    @contextlib.asynccontextmanager
    async def using(self):
        # Hold an "active" slot for the duration of a request so the
        # reaper cannot unload the model mid-flight.
        await self._inc()
        try:
            yield self
        finally:
            await self._dec()

```

```python
# main.py (startup); `app` is the existing FastAPI instance
import asyncio
import contextlib
import logging
import time

from .config import settings
from .services.model_manager import ModelManager

logger = logging.getLogger(__name__)

@app.on_event("startup")
async def start_reaper():
    async def reaper():
        while True:
            try:
                await asyncio.sleep(settings.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
                if not settings.MODEL_EVICTION_ENABLED:
                    continue
                timeout = settings.MODEL_IDLE_TIMEOUT_SECONDS
                if timeout <= 0:
                    continue
                m = ModelManager.instance()
                if m.is_loaded() and m.active() == 0 and (time.time() - m.last_used()) >= timeout:
                    await m.unload()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Idle reaper error: %s", e)
    app.state._model_reaper_task = asyncio.create_task(reaper())

@app.on_event("shutdown")
async def stop_reaper():
    task = getattr(app.state, "_model_reaper_task", None)
    if task:
        task.cancel()
        # asyncio.CancelledError subclasses BaseException on Python 3.8+,
        # so suppress(Exception) would not swallow it here.
        with contextlib.suppress(asyncio.CancelledError):
            await task
```

## Observability
- Logs: model load/unload, reaper decisions, active count.
- Metrics (optional): counters and gauges (load events, active requests, residency time); see the sketch below.
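
If metrics are wanted, a minimal sketch using `prometheus_client` (an assumption; any counter/gauge metrics library works):

```python
# Optional metrics sketch; assumes prometheus_client is installed.
from prometheus_client import Counter, Gauge

MODEL_LOADS = Counter("tts_model_loads_total", "Model load events")
MODEL_UNLOADS = Counter("tts_model_unloads_total", "Model unload events")
ACTIVE_REQUESTS = Gauge("tts_active_requests", "In-flight inference requests")
MODEL_RESIDENT = Gauge("tts_model_resident", "1 while the model is loaded, else 0")

# e.g., ModelManager.load():   MODEL_LOADS.inc();   MODEL_RESIDENT.set(1)
#       ModelManager.unload(): MODEL_UNLOADS.inc(); MODEL_RESIDENT.set(0)
```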

## Safety & Edge Cases
- Avoid unload when `active > 0`.
- Guard multiple loads/unloads with lock.
- Multi-worker servers: each worker manages its own model.
- Cold-start latency: document expected additional latency for first request after idle unload.

## Testing
- Unit tests for `ModelManager`: load/unload idempotency, counter behavior (a sketch follows below).
- Simulated reaper triggering with short timeouts.
- Endpoint tests: concurrency (N simultaneous inferences), ensure no unload mid-flight.
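
A sketch of one such unit test, assuming `pytest` with `pytest-asyncio` and the import path `app.services.model_manager` (both assumptions to adapt to the project's layout):

```python
# Sketch of a ModelManager safety test; pytest-asyncio and the import
# path are assumptions.
import pytest

from app.services.model_manager import ModelManager

@pytest.mark.asyncio
async def test_no_unload_while_request_active():
    mgr = ModelManager()      # fresh instance, bypassing the singleton
    mgr._model = object()     # stand-in for a loaded model
    async with mgr.using():
        await mgr.unload()    # must be a no-op while a request is active
        assert mgr.is_loaded()
    await mgr.unload()        # idle again: eviction is allowed
    assert not mgr.is_loaded()
```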

## Rollout Plan
1. Introduce config + Manager (no reaper), switch endpoints to `using()`.
2. Enable reaper with long timeout in staging; observe logs/metrics.
3. Tune timeout; enable in production.

## Tasks Checklist
- [ ] Add config flags and defaults in `backend/app/config.py`.
- [ ] Create `backend/app/services/model_manager.py`.
- [ ] Register startup/shutdown reaper in app init (`backend/app/main.py`).
- [ ] Refactor endpoints to use `ModelManager.instance().using()` and `get()`.
- [ ] Add logs and optional metrics.
- [ ] Add unit/integration tests.
- [ ] Update README/ops docs.

## Alternatives Considered
- Gunicorn/uvicorn worker preloading with external idle supervisor: more complexity, less portability.
- OS-level cgroup memory pressure eviction: opaque and risky for correctness.

## Configuration Examples

```bash
MODEL_EVICTION_ENABLED=true
MODEL_IDLE_TIMEOUT_SECONDS=900
MODEL_IDLE_CHECK_INTERVAL_SECONDS=60
```