# Unload Model on Idle: Implementation Plan

## Goals

- Automatically unload large TTS model(s) when idle to reduce RAM/VRAM usage.
- Lazy-load on demand without breaking API semantics.
- Configurable timeout and safety controls.

## Requirements

- Config-driven idle timeout and poll interval.
- Thread-/async-safe across concurrent requests.
- No unload while an inference is in progress.
- Clear logs and metrics for load/unload events.

## Configuration

File: `backend/app/config.py`

- Add:
  - `MODEL_IDLE_TIMEOUT_SECONDS: int = 900` (0 disables eviction)
  - `MODEL_IDLE_CHECK_INTERVAL_SECONDS: int = 60`
  - `MODEL_EVICTION_ENABLED: bool = True`
- Bind each setting to the environment variable of the same name.

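The plan does not pin down how `config.py` should read these values; the project may well use pydantic settings. A dependency-free sketch of the same bindings using only the standard library:

```python
import os
from dataclasses import dataclass, field


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to a default."""
    return int(os.getenv(name, default))


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean setting; accepts 1/true/yes (case-insensitive)."""
    return os.getenv(name, str(default)).lower() in ("1", "true", "yes")


@dataclass(frozen=True)
class Settings:
    # 0 disables idle eviction entirely
    MODEL_IDLE_TIMEOUT_SECONDS: int = field(
        default_factory=lambda: _env_int("MODEL_IDLE_TIMEOUT_SECONDS", 900)
    )
    MODEL_IDLE_CHECK_INTERVAL_SECONDS: int = field(
        default_factory=lambda: _env_int("MODEL_IDLE_CHECK_INTERVAL_SECONDS", 60)
    )
    MODEL_EVICTION_ENABLED: bool = field(
        default_factory=lambda: _env_bool("MODEL_EVICTION_ENABLED", True)
    )


settings = Settings()
```

Using `default_factory` means the environment is read at instantiation time, which keeps tests that monkeypatch `os.environ` straightforward.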
## Design

### ModelManager (Singleton)

File: `backend/app/services/model_manager.py` (new)

- Responsibilities:
  - Manage the lifecycle (load/unload) of the TTS model/pipeline.
  - Provide `get()` that returns a ready model (lazy-loading if needed) and updates `last_used`.
  - Track the active request count to block eviction while it is > 0.
- Internals:
  - `self._model` (or components), `self._last_used: float`, `self._active: int`.
  - Locks: an `asyncio.Lock` for load/unload; a separate `asyncio.Lock` (or `asyncio.Semaphore`) for the counter.
  - Optional CUDA cleanup: `torch.cuda.empty_cache()` after unload.
- API:
  - `async def get(self) -> Model`: ensures loaded; bumps `last_used`.
  - `async def load(self)`: idempotent; guarded by lock.
  - `async def unload(self)`: only when `self._active == 0`; clears refs and caches.
  - `def touch(self)`: update `last_used`.
  - Context helper: `def using(self)`: returns an async context manager that increments/decrements `active` safely.

### Idle Reaper Task

Registration: FastAPI startup (e.g., in `backend/app/main.py`)

- Run a background task loop every `MODEL_IDLE_CHECK_INTERVAL_SECONDS` seconds:
  - If eviction is enabled, the timeout is > 0, the model is loaded, `active == 0`, and `now - last_used >= timeout`, call `unload()`.
- Handle cancellation on shutdown.

### API Integration

- Replace direct model access in endpoints with:

```python
manager = ModelManager.instance()
async with manager.using():
    model = await manager.get()
    # perform inference
```

- Optionally call `manager.touch()` at request start for non-inference paths that still need the model resident.

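The hand-rolled `_Ctx` helper class in the pseudocode can equivalently be written with `contextlib.asynccontextmanager`. A minimal, self-contained sketch of just the `using()` mechanism (counter fields only; model handling omitted):

```python
import asyncio
import contextlib
import time


class ModelManager:
    def __init__(self):
        self._active = 0
        self._last_used = time.time()
        self._counter_lock = asyncio.Lock()

    @contextlib.asynccontextmanager
    async def using(self):
        # Increment the active-request count so the reaper will not evict.
        async with self._counter_lock:
            self._active += 1
        try:
            yield self
        finally:
            # Always decrement, even if inference raised, and refresh last_used
            # so the idle clock starts when the request finishes.
            async with self._counter_lock:
                self._active = max(0, self._active - 1)
                self._last_used = time.time()


async def demo():
    m = ModelManager()
    async with m.using():
        assert m._active == 1  # eviction is blocked inside the block
    return m._active
```

Running `asyncio.run(demo())` returns `0`: the count is restored on exit, so the try/finally guarantees the reaper is unblocked even when an endpoint raises.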
## Pseudocode

```python
# services/model_manager.py
import asyncio
import time
from typing import Optional

from .config import settings


class ModelManager:
    _instance: Optional["ModelManager"] = None

    def __init__(self):
        self._model = None
        self._last_used = time.time()
        self._active = 0
        self._lock = asyncio.Lock()
        self._counter_lock = asyncio.Lock()

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def load(self):
        async with self._lock:
            if self._model is not None:
                return  # idempotent: already loaded
            # ... load model/pipeline here ...
            self._model = await load_pipeline()
            self._last_used = time.time()

    async def unload(self):
        async with self._lock:
            if self._model is None:
                return
            if self._active > 0:
                return  # safety: do not unload while in use
            # ... free resources ...
            self._model = None
            try:
                import torch
                torch.cuda.empty_cache()
            except Exception:
                pass  # torch not installed or no CUDA device

    async def get(self):
        if self._model is None:
            await self.load()
        self._last_used = time.time()
        return self._model

    def touch(self):
        self._last_used = time.time()

    async def _inc(self):
        async with self._counter_lock:
            self._active += 1

    async def _dec(self):
        async with self._counter_lock:
            self._active = max(0, self._active - 1)
            self._last_used = time.time()

    def last_used(self):
        return self._last_used

    def is_loaded(self):
        return self._model is not None

    def active(self):
        return self._active

    def using(self):
        manager = self

        class _Ctx:
            async def __aenter__(self):
                await manager._inc()
                return manager

            async def __aexit__(self, exc_type, exc, tb):
                await manager._dec()

        return _Ctx()


# main.py (startup)
import contextlib
import logging

logger = logging.getLogger(__name__)


@app.on_event("startup")
async def start_reaper():
    async def reaper():
        while True:
            try:
                await asyncio.sleep(settings.MODEL_IDLE_CHECK_INTERVAL_SECONDS)
                if not settings.MODEL_EVICTION_ENABLED:
                    continue
                timeout = settings.MODEL_IDLE_TIMEOUT_SECONDS
                if timeout <= 0:
                    continue
                m = ModelManager.instance()
                if m.is_loaded() and m.active() == 0 and (time.time() - m.last_used()) >= timeout:
                    await m.unload()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Idle reaper error: %s", e)

    app.state._model_reaper_task = asyncio.create_task(reaper())


@app.on_event("shutdown")
async def stop_reaper():
    task = getattr(app.state, "_model_reaper_task", None)
    if task:
        task.cancel()
        # suppress(Exception) would miss CancelledError (a BaseException)
        with contextlib.suppress(asyncio.CancelledError):
            await task
```

## Observability

- Logs: model load/unload, reaper decisions, active count.
- Metrics (optional): counters and gauges (load events, active requests, residency time).

## Safety & Edge Cases

- Never unload while `active > 0`.
- Guard loads/unloads with a lock so concurrent calls cannot race.
- Multi-worker servers: each worker process manages (and evicts) its own model copy.
- Cold-start latency: document the expected extra latency on the first request after an idle unload.

## Testing

- Unit tests for `ModelManager`: load/unload idempotency, counter behavior.
- Simulated reaper triggering with short timeouts.
- Endpoint tests: concurrency (N simultaneous inferences); ensure no unload mid-flight.

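The idempotency test can be sketched against a stub whose loader counts real loads (class and attribute names here are illustrative, not the project's):

```python
import asyncio


class StubManager:
    """Minimal stand-in for ModelManager with a counting loader."""

    def __init__(self):
        self._model = None
        self._lock = asyncio.Lock()
        self.load_calls = 0

    async def load(self):
        async with self._lock:
            if self._model is not None:
                return  # idempotent: repeated calls are no-ops
            self.load_calls += 1
            self._model = object()  # stands in for the real pipeline

    async def unload(self):
        async with self._lock:
            self._model = None


async def test_load_is_idempotent():
    m = StubManager()
    # Five concurrent loads must result in exactly one real load.
    await asyncio.gather(*(m.load() for _ in range(5)))
    assert m.load_calls == 1
    # After unload, the next get/load performs one more real load.
    await m.unload()
    await m.load()
    assert m.load_calls == 2


asyncio.run(test_load_is_idempotent())
```

`asyncio.gather` here exercises the lock-guarded double-check; without the `if self._model is not None` re-check inside the lock, the first assertion would fail.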
## Rollout Plan

1. Introduce config + manager (no reaper); switch endpoints to `using()`.
2. Enable the reaper with a long timeout in staging; observe logs/metrics.
3. Tune the timeout; enable in production.

## Tasks Checklist

- [ ] Add config flags and defaults in `backend/app/config.py`.
- [ ] Create `backend/app/services/model_manager.py`.
- [ ] Register the startup/shutdown reaper in app init (`backend/app/main.py`).
- [ ] Refactor endpoints to use `ModelManager.instance().using()` and `get()`.
- [ ] Add logs and optional metrics.
- [ ] Add unit/integration tests.
- [ ] Update README/ops docs.

## Alternatives Considered

- Gunicorn/uvicorn worker preloading with an external idle supervisor: more complexity, less portability.
- OS-level cgroup memory-pressure eviction: opaque and risky for correctness.

## Configuration Examples

```
MODEL_EVICTION_ENABLED=true
MODEL_IDLE_TIMEOUT_SECONDS=900
MODEL_IDLE_CHECK_INTERVAL_SECONDS=60
```