# Chatterbox TTS Backend: Bounded Concurrency + File I/O Offload Plan
Date: 2025-08-14
Owner: Backend
Status: Proposed (ready to implement)
## Goals
- Increase GPU utilization and reduce wall-clock time for dialog generation.
- Keep model lifecycle stable (leveraging current `ModelManager`).
- Minimal-risk changes: no API shape changes to clients.
## Scope
- Implement bounded concurrency for per-line speech chunk generation within a single dialog request.
- Offload audio file writes to threads to overlap GPU compute and disk I/O.
- Add configuration knobs to tune concurrency.
## Current State (References)
- `backend/app/services/dialog_processor_service.py`
- `DialogProcessorService.process_dialog()` iterates items and awaits `tts_service.generate_speech(...)` sequentially (lines ~171–201).
- `backend/app/services/tts_service.py`
- `TTSService.generate_speech()` runs the TTS forward and calls `torchaudio.save(...)` on the event loop thread (blocking).
- `backend/app/services/model_manager.py`
- `ModelManager.using()` tracks active work; prevents idle eviction during requests.
- `backend/app/routers/dialog.py`
- `process_dialog_flow()` expects ordered `segment_files` and then concatenates; good to keep order stable.
## Design Overview
1) Bounded concurrency at dialog level
- Plan all output segments with a stable `segment_idx` (including speech chunks, silence, and reused audio).
- For speech chunks, schedule concurrent async tasks with a global semaphore set by config `TTS_MAX_CONCURRENCY` (start at 3–4).
- Await all tasks and collate results by `segment_idx` to preserve order.
2) File I/O offload
- Replace direct `torchaudio.save(...)` with `await asyncio.to_thread(torchaudio.save, ...)` in `TTSService.generate_speech()`.
- This lets the next GPU forward start while previous file writes happen on worker threads.
## Configuration
Add to `backend/app/config.py`:
- `TTS_MAX_CONCURRENCY: int` (default: `int(os.getenv("TTS_MAX_CONCURRENCY", "3"))`).
- Optional (future): `TTS_ENABLE_AMP_ON_CUDA: bool = True` to allow mixed precision on CUDA only.
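A minimal sketch of the addition, assuming the current module-level layout of `config.py` (env parsing details are illustrative):
```python
import os

# Max concurrent per-chunk TTS generations within a single dialog request.
TTS_MAX_CONCURRENCY: int = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))

# Optional, future: enable float16 autocast on CUDA only (MPS/CPU untouched).
TTS_ENABLE_AMP_ON_CUDA: bool = os.getenv("TTS_ENABLE_AMP_ON_CUDA", "true").lower() in ("1", "true", "yes")
```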
## Implementation Steps
### A. Dialog-level concurrency
- File: `backend/app/services/dialog_processor_service.py`
- Function: `DialogProcessorService.process_dialog()`
1. Planning pass to assign indices
- Iterate `dialog_items` and build a list of `planned_segments` entries (see the sketch after these steps):
- For silence or reuse: immediately append a final result with assigned `segment_idx` and continue.
- For speech: split into `text_chunks`; for each chunk create a planned entry: `{ segment_idx, type: 'speech', speaker_id, text_chunk, abs_speaker_sample_path, tts_params }`.
- Increment `segment_idx` for every planned segment (speech chunk or silence/reuse) to preserve final order.
2. Concurrency setup
- Create `sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)`.
- For each planned speech segment, create a task with an inner wrapper:
```python
async def run_one(planned):
    async with sem:
        try:
            out_path = await self.tts_service.generate_speech(
                text=planned.text_chunk,
                speaker_sample_path=planned.abs_speaker_sample_path,
                output_filename_base=planned.filename_base,
                output_dir=dialog_temp_dir,
                exaggeration=planned.exaggeration,
                cfg_weight=planned.cfg_weight,
                temperature=planned.temperature,
            )
            return planned.segment_idx, {
                "type": "speech",
                "path": str(out_path),
                "speaker_id": planned.speaker_id,
                "text_chunk": planned.text_chunk,
            }
        except Exception as e:
            return planned.segment_idx, {
                "type": "error",
                "message": f"Error generating speech: {e}",
                "text_chunk": planned.text_chunk,
            }
```
- Schedule with `asyncio.create_task(run_one(p))` and collect tasks.
3. Await and collate
- `results_map = {}`; for each completed task, set `results_map[idx] = payload`.
- Merge: start with all previously final (silence/reuse/error) entries placed by `segment_idx`, then fill speech results by `segment_idx` into a single `segment_results` list sorted ascending by index.
- Keep `processing_log` entries for each planned segment (queued, started, finished, errors).
4. Return value unchanged
- Return `{"log": ..., "segment_files": segment_results, "temp_dir": str(dialog_temp_dir)}`. This maintains router and concatenator behavior.
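Putting steps 1–3 together, a minimal sketch of the flow, assuming a hypothetical `PlannedSegment` container, an injected `split_into_chunks()` helper, and the `run_one()` wrapper from step 2 (which closes over the shared semaphore):
```python
import asyncio
from dataclasses import dataclass

@dataclass
class PlannedSegment:  # hypothetical; step 2's run_one uses attribute access
    segment_idx: int
    speaker_id: str
    text_chunk: str

async def schedule_and_collate(dialog_items, run_one, split_into_chunks):
    # Planning pass (step 1): assign a stable segment_idx to every output segment.
    final_results: dict[int, dict] = {}  # silence/reuse, final at planning time
    planned: list[PlannedSegment] = []
    segment_idx = 0
    for item in dialog_items:
        if item["type"] in ("silence", "reuse"):
            final_results[segment_idx] = dict(item)
            segment_idx += 1
            continue
        for chunk in split_into_chunks(item["text"]):
            planned.append(PlannedSegment(segment_idx, item["speaker_id"], chunk))
            segment_idx += 1
    # Scheduling (step 2): run_one acquires the semaphore internally.
    tasks = [asyncio.create_task(run_one(p)) for p in planned]
    # Collation (step 3): fill speech results into the final entries by index.
    results_map = dict(final_results)
    for idx, payload in await asyncio.gather(*tasks):
        results_map[idx] = payload
    return [results_map[i] for i in sorted(results_map)]
```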
### B. Offload audio writes
- File: `backend/app/services/tts_service.py`
- Function: `TTSService.generate_speech()`
1. After obtaining the `wav` tensor, replace:
```python
# torchaudio.save(str(output_file_path), wav, self.model.sr)
```
with:
```python
await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
```
- Keep the rest of the cleanup logic (delete `wav`, `gc.collect()`, cache emptying) unchanged.
2. Optional (CUDA-only AMP)
- If CUDA is used and `config.TTS_ENABLE_AMP_ON_CUDA` is True, wrap forward with AMP:
```python
with torch.cuda.amp.autocast(dtype=torch.float16):
    wav = self.model.generate(...)
```
- Leave MPS/CPU code path as-is.
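Combined, the relevant portion of `generate_speech()` could look roughly like the sketch below. The `model.generate` keyword argument is an assumption inferred from the parameters this plan passes around, not a confirmed signature:
```python
import asyncio

import torch
import torchaudio

async def _forward_and_save(self, text, speaker_sample_path, output_file_path):
    # Sketch of a TTSService method; the real one also handles filenames,
    # logging, and cleanup.
    if self.device == "cuda" and config.TTS_ENABLE_AMP_ON_CUDA:
        with torch.cuda.amp.autocast(dtype=torch.float16):
            wav = self.model.generate(text, audio_prompt_path=speaker_sample_path)
    else:
        wav = self.model.generate(text, audio_prompt_path=speaker_sample_path)
    # The blocking disk write moves to a worker thread, so the event loop can
    # dispatch the next segment while this file is being written.
    await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
```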
## Error Handling & Ordering
- Every planned segment owns a unique `segment_idx`.
- On failure, insert an error record at that index; downstream concatenation already skips missing/nonexistent paths.
- Preserve exact output order expected by `routers/dialog.py::process_dialog_flow()`.
## Performance Expectations
- GPU utilization should increase from ~50% to 75–90% depending on dialog size and line lengths.
- Wall-clock reduction is workload-dependent; target 1.5–2.5x on multi-line dialogs.
## Metrics & Instrumentation
- Add timestamped log entries per segment: planned→queued→started→saved.
- Log effective concurrency (max in-flight), and cumulative GPU time if available.
- Optionally add a simple timing summary at end of `process_dialog()`.
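A lightweight way to log effective concurrency inside the step-2 wrapper; counter and helper names here are illustrative:
```python
import asyncio
import logging
import time

logger = logging.getLogger(__name__)
sem = asyncio.Semaphore(3)  # config.TTS_MAX_CONCURRENCY in the real code

# Counters shared by all run_one tasks (single event loop: no lock needed).
stats = {"in_flight": 0, "max_in_flight": 0}

async def run_one(planned, generate_segment):
    async with sem:
        stats["in_flight"] += 1
        stats["max_in_flight"] = max(stats["max_in_flight"], stats["in_flight"])
        started = time.monotonic()
        try:
            return await generate_segment(planned)  # hypothetical: the step-2 body
        finally:
            stats["in_flight"] -= 1
            logger.info("segment %d done in %.2fs (peak in-flight so far: %d)",
                        planned.segment_idx, time.monotonic() - started,
                        stats["max_in_flight"])
```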
## Testing Plan
1. Unit-ish
- Small dialog (3 speech lines, 1 silence). Ensure ordering is stable and files exist.
- Introduce an invalid speaker to verify error propagation doesn't break the rest.
2. Integration
- POST `/api/dialog/generate` with 20–50 mixed-length lines and a couple of silences.
- Validate: response OK, concatenated file exists, zip contains all generated speech segments, order preserved.
- Compare runtime vs. sequential baseline (before/after).
3. Stress/limits
- Long lines split into many chunks; verify no OOM with `TTS_MAX_CONCURRENCY`=3.
- Try `TTS_MAX_CONCURRENCY`=1 to simulate sequential; compare metrics.
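For the ordering check, a sketch of the unit-ish case; the `dialog_processor` fixture, the async test plugin, the item schema, and the `process_dialog()` call shape are all assumptions:
```python
from pathlib import Path

import pytest

@pytest.mark.asyncio
async def test_order_stable_with_mixed_segments(dialog_processor):
    # 3 speech lines and 1 silence, matching the unit-ish case above.
    items = [
        {"type": "speech", "speaker_id": "alice", "text": "Line one."},
        {"type": "speech", "speaker_id": "bob", "text": "Line two."},
        {"type": "silence", "duration": 0.5},
        {"type": "speech", "speaker_id": "alice", "text": "Line three."},
    ]
    result = await dialog_processor.process_dialog(items)  # call shape assumed
    types = [seg["type"] for seg in result["segment_files"]]
    assert types == ["speech", "speech", "silence", "speech"]
    assert all(Path(seg["path"]).exists()
               for seg in result["segment_files"] if seg["type"] == "speech")
```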
## Rollout & Config Defaults
- Default `TTS_MAX_CONCURRENCY=3`.
- Expose via environment variable; no client changes needed.
- If instability observed, set `TTS_MAX_CONCURRENCY=1` to revert to sequential behavior quickly.
## Risks & Mitigations
- OOM under high concurrency → Mitigate with low default, easy rollback, and chunking already in place.
- Disk I/O saturation → Offload to threads; if disk is a bottleneck, decrease concurrency.
- Model thread safety → We call `model.generate` concurrently only up to the semaphore cap; if the underlying library is not safe for concurrent forward passes, serialize the forwards while still overlapping them with file I/O (sketched below). Early logs will reveal whether this is needed.
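If serialization proves necessary, one hedged shape for it; the lock name and `generate` call signature are illustrative:
```python
import asyncio

import torchaudio

# One forward at a time; file writes stay outside the lock so a finished
# segment can be written to disk while the next forward runs.
_forward_lock = asyncio.Lock()

async def generate_one(model, text, speaker_sample_path, output_file_path):
    async with _forward_lock:
        wav = await asyncio.to_thread(
            model.generate, text, audio_prompt_path=speaker_sample_path)
    await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, model.sr)
```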
## Follow-up (Out of Scope for this change)
- Dynamic batching queue inside `TTSService` for further GPU efficiency.
- CUDA AMP enablement and profiling.
- Per-speaker sub-queues if batching requires same-speaker inputs.
## Acceptance Criteria
- `TTS_MAX_CONCURRENCY` is configurable; default=3.
- File writes occur via `asyncio.to_thread`.
- Order of `segment_files` unchanged relative to sequential output.
- End-to-end works for both small and large dialogs; error cases logged.
- Observed GPU utilization and runtime improve on representative dialog.