# Unload Model on Idle: Implementation Plan

## Goals
- Automatically unload large TTS model(s) when idle to reduce RAM/VRAM usage.
- Lazy-load on demand without breaking API semantics.
- Configurable timeout and safety controls.
## Requirements
- Config-driven idle timeout and poll interval.
- Thread-/async-safe across concurrent requests.
- No unload while an inference is in progress.
- Clear logs and metrics for load/unload events.
## Configuration

File: `backend/app/config.py`
- Add:
  - `MODEL_IDLE_TIMEOUT_SECONDS: int = 900` (0 disables eviction)
  - `MODEL_IDLE_CHECK_INTERVAL_SECONDS: int = 60`
  - `MODEL_EVICTION_ENABLED: bool = True`
- Bind to env: `MODEL_IDLE_TIMEOUT_SECONDS`, `MODEL_IDLE_CHECK_INTERVAL_SECONDS`, `MODEL_EVICTION_ENABLED`.
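A minimal sketch of the additions, assuming the project's config uses pydantic's `BaseSettings` (in pydantic v2 it lives in the separate `pydantic-settings` package); mirror whatever pattern `backend/app/config.py` already uses:

```python
# backend/app/config.py (sketch; assumes a pydantic BaseSettings-style config)
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # Seconds of inactivity before the model is unloaded; 0 disables eviction.
    MODEL_IDLE_TIMEOUT_SECONDS: int = 900
    # How often the reaper task wakes up to check idleness.
    MODEL_IDLE_CHECK_INTERVAL_SECONDS: int = 60
    # Master switch for the eviction feature.
    MODEL_EVICTION_ENABLED: bool = True


settings = Settings()  # BaseSettings binds each field to a same-named env var
```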
## Design

### ModelManager (Singleton)

File: `backend/app/services/model_manager.py` (new)
- Responsibilities:
  - Manage the lifecycle (load/unload) of the TTS model/pipeline.
  - Provide `get()` that returns a ready model (lazy-loading if needed) and updates `last_used`.
  - Track the active request count to block eviction while it is > 0.
- Internals:
  - `self._model` (or components), `self._last_used: float`, `self._active: int`.
  - Locks: `asyncio.Lock` for load/unload; `asyncio.Lock` or `asyncio.Semaphore` for the counters.
  - Optional CUDA cleanup: `torch.cuda.empty_cache()` after unload.
- API:
  - `async def get(self) -> Model`: ensures the model is loaded; bumps `last_used`.
  - `async def load(self)`: idempotent; guarded by the lock.
  - `async def unload(self)`: only when `self._active == 0`; clears references and caches.
  - `def touch(self)`: updates `last_used`.
  - Context helper: `async def using(self)`: async context manager that increments/decrements `active` safely.
### Idle Reaper Task

Registration: FastAPI startup (e.g., in `backend/app/main.py`)
- Background task loops every `MODEL_IDLE_CHECK_INTERVAL_SECONDS`:
  - If eviction is enabled, the timeout is > 0, the model is loaded, `active == 0`, and `now - last_used >= timeout`, call `unload()`.
- Handle cancellation on shutdown.
## API Integration

- Replace direct model access in endpoints with:

  ```python
  manager = ModelManager.instance()
  async with manager.using():
      model = await manager.get()
      # perform inference
  ```

- Optionally call `manager.touch()` at request start for non-inference paths that still need the model resident.
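For concreteness, a sketch of a full endpoint wired through the manager; the `/tts` route and the model's `synthesize()` method are hypothetical placeholders for the project's actual inference call:

```python
# Example route (sketch): the path and model.synthesize() are placeholders.
from fastapi import APIRouter

from .services.model_manager import ModelManager

router = APIRouter()


@router.post("/tts")
async def tts(text: str):
    manager = ModelManager.instance()
    async with manager.using():        # holds the active count; blocks eviction
        model = await manager.get()    # lazy-loads on a cold start
        audio = await model.synthesize(text)  # hypothetical inference call
    return {"audio": audio}
```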
## Pseudocode

```python
# services/model_manager.py
import asyncio
import time
from typing import Optional


class ModelManager:
    _instance: Optional["ModelManager"] = None

    def __init__(self):
        self._model = None
        self._last_used = time.time()
        self._active = 0
        self._lock = asyncio.Lock()          # guards load/unload
        self._counter_lock = asyncio.Lock()  # guards the active-request counter

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def load(self):
        async with self._lock:
            if self._model is not None:
                return  # already loaded; idempotent
            # ... load model/pipeline here ...
            self._model = await load_pipeline()
            self._last_used = time.time()

    async def unload(self):
        async with self._lock:
            if self._model is None:
                return
            if self._active > 0:
                return  # safety: do not unload while in use
            # ... free resources ...
            self._model = None
            try:
                import torch
                torch.cuda.empty_cache()
            except Exception:
                pass

    async def get(self):
        if self._model is None:
            await self.load()
        self._last_used = time.time()
        return self._model

    def touch(self):
        self._last_used = time.time()

    async def _inc(self):
        async with self._counter_lock:
            self._active += 1

    async def _dec(self):
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self):
        return self._last_used

    def is_loaded(self):
        return self._model is not None

    def active(self):
        return self._active

    def using(self):
        manager = self

        class _Ctx:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()

        return _Ctx()
```
```python
# main.py (startup/shutdown registration)
import asyncio
import contextlib
import logging
import time

from .config import settings
from .services.model_manager import ModelManager

logger = logging.getLogger(__name__)


@app.on_event("startup")
async def start_reaper():
    async def reaper():
        while True:
            try:
                await asyncio.sleep(settings.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
                if not settings.MODEL_EVICTION_ENABLED:
                    continue
                timeout = settings.MODEL_IDLE_TIMEOUT_SECONDS
                if timeout <= 0:
                    continue  # 0 disables eviction
                m = ModelManager.instance()
                if m.is_loaded() and m.active() == 0 and (time.time() - m.last_used()) >= timeout:
                    await m.unload()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Idle reaper error: %s", e)

    app.state._model_reaper_task = asyncio.create_task(reaper())


@app.on_event("shutdown")
async def stop_reaper():
    task = getattr(app.state, "_model_reaper_task", None)
    if task:
        task.cancel()
        # CancelledError is a BaseException, so suppress it explicitly;
        # suppress(Exception) would not catch it.
        with contextlib.suppress(asyncio.CancelledError):
            await task
```
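`@app.on_event` is deprecated in recent FastAPI releases; if the project is on a newer version, the same registration can use a lifespan handler. A minimal sketch, reusing the `reaper()` loop above:

```python
# main.py alternative (sketch): lifespan handler, available in newer FastAPI.
import asyncio
import contextlib
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    task = asyncio.create_task(reaper())  # same reaper loop as above
    yield                                  # application serves requests
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


app = FastAPI(lifespan=lifespan)
```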
## Observability
- Logs: model load/unload, reaper decisions, active count.
- Metrics (optional): counters and gauges (load events, active, residency time).
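If Prometheus is available, a minimal sketch with `prometheus_client`; the metric names are suggestions, not an existing convention:

```python
# Optional metrics (sketch): names are illustrative.
from prometheus_client import Counter, Gauge

MODEL_LOADS = Counter("tts_model_loads_total", "Times the TTS model was loaded")
MODEL_UNLOADS = Counter("tts_model_unloads_total", "Times the TTS model was unloaded")
ACTIVE_REQUESTS = Gauge("tts_model_active_requests", "In-flight inference requests")
MODEL_RESIDENT = Gauge("tts_model_loaded", "1 if the model is resident, else 0")

# e.g. in ModelManager.load(): MODEL_LOADS.inc(); MODEL_RESIDENT.set(1)
```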
## Safety & Edge Cases
- Avoid unload when `active > 0`.
- Guard concurrent loads/unloads with the load/unload lock.
- Multi-worker servers: each worker manages its own model.
- Cold-start latency: document expected additional latency for first request after idle unload.
## Testing
- Unit tests for `ModelManager`: load/unload idempotency, counter behavior.
- Simulated reaper triggering with short timeouts.
- Endpoint tests: concurrency (N simultaneous inferences), ensure no unload mid-flight.
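A test sketch, assuming `pytest-asyncio` and that `load_pipeline` can be monkeypatched onto the module so tests never touch the real model (the import path is illustrative):

```python
# tests/test_model_manager.py (sketch; assumes pytest-asyncio is installed)
import pytest

import app.services.model_manager as mm


async def fake_pipeline():
    return object()  # stand-in model object


@pytest.mark.asyncio
async def test_load_is_idempotent(monkeypatch):
    monkeypatch.setattr(mm, "load_pipeline", fake_pipeline, raising=False)
    manager = mm.ModelManager()
    await manager.load()
    first = manager._model
    await manager.load()  # second call must be a no-op
    assert manager._model is first


@pytest.mark.asyncio
async def test_no_unload_while_request_in_flight(monkeypatch):
    monkeypatch.setattr(mm, "load_pipeline", fake_pipeline, raising=False)
    manager = mm.ModelManager()
    async with manager.using():
        await manager.get()
        await manager.unload()  # must refuse: active count is 1
        assert manager.is_loaded()
    await manager.unload()      # idle now, unload proceeds
    assert not manager.is_loaded()
```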
## Rollout Plan
1. Introduce config + Manager (no reaper), switch endpoints to `using()`.
2. Enable reaper with long timeout in staging; observe logs/metrics.
3. Tune timeout; enable in production.
## Tasks Checklist
- [ ] Add config flags and defaults in `backend/app/config.py`.
- [ ] Create `backend/app/services/model_manager.py`.
- [ ] Register startup/shutdown reaper in app init (`backend/app/main.py`).
- [ ] Refactor endpoints to use `ModelManager.instance().using()` and `get()`.
- [ ] Add logs and optional metrics.
- [ ] Add unit/integration tests.
- [ ] Update README/ops docs.
## Alternatives Considered
- Gunicorn/uvicorn worker preloading with external idle supervisor: more complexity, less portability.
- OS-level cgroup memory pressure eviction: opaque and risky for correctness.
## Configuration Examples
```
MODEL_EVICTION_ENABLED=true
MODEL_IDLE_TIMEOUT_SECONDS=900
MODEL_IDLE_CHECK_INTERVAL_SECONDS=60
```