Chatterbox TTS Backend: Bounded Concurrency + File I/O Offload Plan
Date: 2025-08-14
Owner: Backend
Status: Proposed (ready to implement)
Goals
- Increase GPU utilization and reduce wall-clock time for dialog generation.
- Keep model lifecycle stable (leveraging the current `ModelManager`).
- Minimal-risk changes: no API shape changes to clients.
Scope
- Implement bounded concurrency for per-line speech chunk generation within a single dialog request.
- Offload audio file writes to threads to overlap GPU compute and disk I/O.
- Add configuration knobs to tune concurrency.
Current State (References)
- `backend/app/services/dialog_processor_service.py`
  - `DialogProcessorService.process_dialog()` iterates items and awaits `tts_service.generate_speech(...)` sequentially (lines ~171–201).
- `backend/app/services/tts_service.py`
  - `TTSService.generate_speech()` runs the TTS forward and calls `torchaudio.save(...)` on the event loop thread (blocking).
- `backend/app/services/model_manager.py`
  - `ModelManager.using()` tracks active work; prevents idle eviction during requests.
- `backend/app/routers/dialog.py`
  - `process_dialog_flow()` expects ordered `segment_files` and then concatenates; good to keep order stable.
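For orientation, the current behavior reduces to a sequential loop along these lines (a simplified sketch, not the actual implementation; `split_text` and the `dialog_items` fields are placeholders, and only a subset of the real `generate_speech` arguments is shown):

```python
# Inside the current async process_dialog(), roughly:
segment_files = []
for item in dialog_items:
    for chunk in split_text(item["text"]):            # placeholder splitter
        # One chunk at a time; the file write inside generate_speech()
        # also blocks the event loop.
        out_path = await tts_service.generate_speech(
            text=chunk,
            speaker_sample_path=item["speaker_sample_path"],
        )
        segment_files.append(out_path)
```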
Design Overview
- Bounded concurrency at dialog level
  - Plan all output segments with a stable `segment_idx` (including speech chunks, silence, and reused audio).
  - For speech chunks, schedule concurrent async tasks gated by a global semaphore sized by config `TTS_MAX_CONCURRENCY` (start at 3–4).
  - Await all tasks and collate results by `segment_idx` to preserve order.
- File I/O offload
  - Replace the direct `torchaudio.save(...)` with `await asyncio.to_thread(torchaudio.save, ...)` in `TTSService.generate_speech()`.
  - This lets the next GPU forward start while previous file writes happen on worker threads (see the sketch below).
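The overall pattern, independent of the concrete service code, looks roughly like this. `render_chunk` and `save_wav` are placeholders standing in for the TTS forward and `torchaudio.save`, and `planned.segment_idx` / `planned.path` are assumed fields of the planned entries:

```python
import asyncio

async def generate_all(planned_speech, max_concurrency: int = 3):
    sem = asyncio.Semaphore(max_concurrency)

    async def run_one(planned):
        async with sem:                                    # at most N chunks in flight
            wav = await render_chunk(planned)              # placeholder for the TTS forward
            # Offload the blocking disk write so the loop can start the next forward.
            await asyncio.to_thread(save_wav, planned.path, wav)
            return planned.segment_idx, planned.path

    tasks = [asyncio.create_task(run_one(p)) for p in planned_speech]
    results = dict(await asyncio.gather(*tasks))
    # Collate by segment_idx so output order matches the planned order.
    return [results[i] for i in sorted(results)]
```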
Configuration
Add to `backend/app/config.py`:
- `TTS_MAX_CONCURRENCY: int` (default: `int(os.getenv("TTS_MAX_CONCURRENCY", "3"))`).
- Optional (future): `TTS_ENABLE_AMP_ON_CUDA: bool = True` to allow mixed precision on CUDA only.
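Assuming `config.py` holds plain module-level settings (the exact structure may differ), the addition could look like:

```python
# backend/app/config.py (sketch; adapt to the existing settings style)
import os

# Maximum number of TTS chunk generations in flight per process.
TTS_MAX_CONCURRENCY: int = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))

# Future: allow float16 autocast on CUDA only (no effect on MPS/CPU paths).
TTS_ENABLE_AMP_ON_CUDA: bool = os.getenv("TTS_ENABLE_AMP_ON_CUDA", "true").lower() == "true"
```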
Implementation Steps
A. Dialog-level concurrency
- File: `backend/app/services/dialog_processor_service.py`
- Function: `DialogProcessorService.process_dialog()`
- Planning pass to assign indices (see the combined sketch at the end of this section)
  - Iterate `dialog_items` and build a list of `planned_segments` entries:
    - For silence or reuse: immediately append a final result with the assigned `segment_idx` and continue.
    - For speech: split into `text_chunks`; for each chunk create a planned entry: `{ segment_idx, type: 'speech', speaker_id, text_chunk, abs_speaker_sample_path, tts_params }`.
    - Increment `segment_idx` for every planned segment (speech chunk or silence/reuse) to preserve final order.
- Concurrency setup
  - Create `sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)`.
  - For each planned speech segment, create a task with an inner wrapper:

    ```python
    async def run_one(planned):
        async with sem:
            try:
                out_path = await self.tts_service.generate_speech(
                    text=planned.text_chunk,
                    speaker_sample_path=planned.abs_speaker_sample_path,
                    output_filename_base=planned.filename_base,
                    output_dir=dialog_temp_dir,
                    exaggeration=planned.exaggeration,
                    cfg_weight=planned.cfg_weight,
                    temperature=planned.temperature,
                )
                return planned.segment_idx, {
                    "type": "speech",
                    "path": str(out_path),
                    "speaker_id": planned.speaker_id,
                    "text_chunk": planned.text_chunk,
                }
            except Exception as e:
                return planned.segment_idx, {
                    "type": "error",
                    "message": f"Error generating speech: {e}",
                    "text_chunk": planned.text_chunk,
                }
    ```

  - Schedule with `asyncio.create_task(run_one(p))` and collect the tasks.
- Await and collate
  - `results_map = {}`; for each completed task, set `results_map[idx] = payload`.
  - Merge: start with all previously finalized (silence/reuse/error) entries placed by `segment_idx`, then fill speech results by `segment_idx` into a single `segment_results` list sorted ascending by index.
  - Keep `processing_log` entries for each planned segment (queued, started, finished, errors).
- Return value unchanged
  - Return `{"log": ..., "segment_files": segment_results, "temp_dir": str(dialog_temp_dir)}`. This maintains router and concatenator behavior.
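Putting section A together, the flow inside `process_dialog()` could look roughly like the sketch below. `split_text`, `make_final_result`, and the `dialog_items` field names are placeholders for whatever the existing code already uses; `run_one` is the wrapper shown above:

```python
from types import SimpleNamespace

# Fragment inside the async process_dialog(); asyncio and config are already imported.

# 1) Planning pass: assign a stable segment_idx to every output segment.
segment_idx = 0
planned_speech = []   # speech chunks to generate concurrently
finalized = {}        # segment_idx -> payload for silence/reuse (errors land here too)

for item in dialog_items:
    if item["type"] in ("silence", "reuse"):
        finalized[segment_idx] = make_final_result(item)      # placeholder helper
        segment_idx += 1
        continue
    for chunk in split_text(item["text"]):                    # placeholder splitter
        planned_speech.append(SimpleNamespace(
            segment_idx=segment_idx,
            speaker_id=item["speaker_id"],
            text_chunk=chunk,
            abs_speaker_sample_path=item["speaker_sample_path"],
            filename_base=f"{item['speaker_id']}_seg{segment_idx:04d}",  # naming is illustrative
            exaggeration=item.get("exaggeration"),   # per-item TTS params; real defaults
            cfg_weight=item.get("cfg_weight"),       # come from the request/config
            temperature=item.get("temperature"),
        ))
        segment_idx += 1

# 2) Schedule the speech chunks under the semaphore and wait for all of them.
sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)
tasks = [asyncio.create_task(run_one(p)) for p in planned_speech]

# 3) Collate: merge speech payloads into the pre-finalized entries, ordered by index.
for idx, payload in await asyncio.gather(*tasks):
    finalized[idx] = payload
segment_results = [finalized[i] for i in sorted(finalized)]
```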
B. Offload audio writes
- File: `backend/app/services/tts_service.py`
- Function: `TTSService.generate_speech()`
- After obtaining the `wav` tensor, replace:

  ```python
  # torchaudio.save(str(output_file_path), wav, self.model.sr)
  ```

  with:

  ```python
  await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
  ```

- Keep the rest of the cleanup logic (delete `wav`, `gc.collect()`, cache emptying) unchanged; a fuller fragment is sketched after this list.
- Optional (CUDA-only AMP)
  - If CUDA is used and `config.TTS_ENABLE_AMP_ON_CUDA` is True, wrap the forward with AMP:

    ```python
    with torch.cuda.amp.autocast(dtype=torch.float16):
        wav = self.model.generate(...)
    ```

  - Leave the MPS/CPU code path as-is.
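For orientation, the tail of `TTSService.generate_speech()` after the save is offloaded could look roughly like the fragment below. This is a sketch only: the `output_file_path` construction and the exact cleanup calls depend on the existing implementation, and the module is assumed to already import `asyncio`, `gc`, `torch`, `torchaudio`, and `pathlib.Path`.

```python
# Fragment inside TTSService.generate_speech(), after the forward produced `wav`.
output_file_path = Path(output_dir) / f"{output_filename_base}.wav"

# The blocking disk write now runs on a worker thread, so the event loop is
# free to start the next chunk's forward pass.
await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)

# Existing cleanup stays unchanged.
del wav
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

return output_file_path
```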
Error Handling & Ordering
- Every planned segment owns a unique `segment_idx`.
- On failure, insert an error record at that index; downstream concatenation already skips missing/nonexistent paths.
- Preserve the exact output order expected by `routers/dialog.py::process_dialog_flow()`.
Performance Expectations
- GPU util should increase from ~50% to 75–90% depending on dialog size and line lengths.
- Wall-clock reduction is workload-dependent; target 1.5–2.5x on multi-line dialogs.
Metrics & Instrumentation
- Add timestamped log entries per segment: planned→queued→started→saved.
- Log effective concurrency (max in-flight), and cumulative GPU time if available.
- Optionally add a simple timing summary at the end of `process_dialog()`.
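A minimal sketch of the end-of-dialog timing summary, assuming the standard `logging` module is used (the project may instead append this to `processing_log`):

```python
import logging
import time

logger = logging.getLogger(__name__)

# Call once at the end of process_dialog(); t_start = time.monotonic() taken at entry.
def log_timing_summary(t_start: float, n_segments: int) -> None:
    elapsed = time.monotonic() - t_start
    logger.info("dialog complete: %d segments in %.1fs (%.2fs/segment average)",
                n_segments, elapsed, elapsed / max(n_segments, 1))
```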
Testing Plan
- Unit-ish
  - Small dialog (3 speech lines, 1 silence). Ensure ordering is stable and files exist (an isolated ordering test is sketched after this list).
  - Introduce an invalid speaker to verify error propagation doesn’t break the rest.
- Integration
  - POST `/api/dialog/generate` with 20–50 mixed-length lines and a couple of silences.
  - Validate: response OK, concatenated file exists, zip contains all generated speech segments, order preserved.
  - Compare runtime vs. the sequential baseline (before/after).
- Stress/limits
  - Long lines split into many chunks; verify no OOM with `TTS_MAX_CONCURRENCY=3`.
  - Try `TTS_MAX_CONCURRENCY=1` to simulate sequential behavior; compare metrics.
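The ordering guarantee can be exercised without the model at all; the sketch below drives the semaphore/collate pattern with a fake generator (assumes `pytest-asyncio` is available; all names are test-local, not existing fixtures):

```python
import asyncio
import random

import pytest

@pytest.mark.asyncio
async def test_order_is_stable_under_concurrency():
    sem = asyncio.Semaphore(3)

    async def fake_generate(idx: int):
        async with sem:
            # Random latency so completion order differs from submission order.
            await asyncio.sleep(random.uniform(0.001, 0.01))
            return idx, f"segment_{idx}.wav"

    tasks = [asyncio.create_task(fake_generate(i)) for i in range(20)]
    results = dict(await asyncio.gather(*tasks))
    ordered = [results[i] for i in sorted(results)]

    assert ordered == [f"segment_{i}.wav" for i in range(20)]
```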
Rollout & Config Defaults
- Default `TTS_MAX_CONCURRENCY=3`.
- Expose via environment variable; no client changes needed.
- If instability is observed, set `TTS_MAX_CONCURRENCY=1` to revert to sequential behavior quickly.
Risks & Mitigations
- OOM under high concurrency → Mitigate with low default, easy rollback, and chunking already in place.
- Disk I/O saturation → Offload to threads; if disk is a bottleneck, decrease concurrency.
- Model thread safety → We call `model.generate` concurrently only up to the semaphore cap; if the underlying library is not thread-safe for forward passes, consider serializing forwards while still overlapping with file I/O; early logs will reveal this.
Follow-up (Out of Scope for this change)
- Dynamic batching queue inside `TTSService` for further GPU efficiency.
- CUDA AMP enablement and profiling.
- Per-speaker sub-queues if batching requires same-speaker inputs.
Acceptance Criteria
- `TTS_MAX_CONCURRENCY` is configurable; default=3.
- File writes occur via `asyncio.to_thread`.
- Order of `segment_files` unchanged relative to sequential output.
- End-to-end works for both small and large dialogs; error cases logged.
- Observed GPU utilization and runtime improve on representative dialog.