chatterbox-ui/.note/concurrency_plan.md

Chatterbox TTS Backend: Bounded Concurrency + File I/O Offload Plan

Date: 2025-08-14 | Owner: Backend | Status: Proposed (ready to implement)

Goals

  • Increase GPU utilization and reduce wall-clock time for dialog generation.
  • Keep model lifecycle stable (leveraging current ModelManager).
  • Minimal-risk changes: no API shape changes to clients.

Scope

  • Implement bounded concurrency for per-line speech chunk generation within a single dialog request.
  • Offload audio file writes to threads to overlap GPU compute and disk I/O.
  • Add configuration knobs to tune concurrency.

Current State (References)

  • backend/app/services/dialog_processor_service.py
    • DialogProcessorService.process_dialog() iterates items and awaits tts_service.generate_speech(...) sequentially (lines ~171–201).
  • backend/app/services/tts_service.py
    • TTSService.generate_speech() runs the TTS forward and calls torchaudio.save(...) on the event loop thread (blocking).
  • backend/app/services/model_manager.py
    • ModelManager.using() tracks active work; prevents idle eviction during requests.
  • backend/app/routers/dialog.py
    • process_dialog_flow() expects ordered segment_files and then concatenates; good to keep order stable.

Design Overview

  1. Bounded concurrency at dialog level
  • Plan all output segments with a stable segment_idx (including speech chunks, silence, and reused audio).
  • For speech chunks, schedule concurrent async tasks with a global semaphore set by config TTS_MAX_CONCURRENCY (start at 3–4).
  • Await all tasks and collate results by segment_idx to preserve order.
  2. File I/O offload
  • Replace direct torchaudio.save(...) with await asyncio.to_thread(torchaudio.save, ...) in TTSService.generate_speech().
  • This lets the next GPU forward start while previous file writes happen on worker threads.

Configuration

Add to backend/app/config.py:

  • TTS_MAX_CONCURRENCY: int (default: int(os.getenv("TTS_MAX_CONCURRENCY", "3"))).
  • Optional (future): TTS_ENABLE_AMP_ON_CUDA: bool = True to allow mixed precision on CUDA only.
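
A minimal sketch of the addition, assuming backend/app/config.py is a flat module of settings constants (adjust to match the existing Settings structure if one is used):

    import os

    # Upper bound on how many speech chunks are generated concurrently.
    TTS_MAX_CONCURRENCY: int = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))

    # Optional (future): allow mixed precision on CUDA only.
    TTS_ENABLE_AMP_ON_CUDA: bool = os.getenv("TTS_ENABLE_AMP_ON_CUDA", "true").lower() == "true"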

Implementation Steps

A. Dialog-level concurrency

  • File: backend/app/services/dialog_processor_service.py
  • Function: DialogProcessorService.process_dialog()
  1. Planning pass to assign indices
  • Iterate dialog_items and build a planned_segments list of entries:
    • For silence or reuse: immediately append a final result with assigned segment_idx and continue.
    • For speech: split into text_chunks; for each chunk create a planned entry: { segment_idx, type: 'speech', speaker_id, text_chunk, abs_speaker_sample_path, tts_params }.
    • Increment segment_idx for every planned segment (speech chunk or silence/reuse) to preserve final order.
  2. Concurrency setup
  • Create sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY).

  • For each planned speech segment, create a task with an inner wrapper:

    async def run_one(planned):
        async with sem:
            try:
                out_path = await self.tts_service.generate_speech(
                    text=planned.text_chunk,
                    speaker_sample_path=planned.abs_speaker_sample_path,
                    output_filename_base=planned.filename_base,
                    output_dir=dialog_temp_dir,
                    exaggeration=planned.exaggeration,
                    cfg_weight=planned.cfg_weight,
                    temperature=planned.temperature,
                )
                return planned.segment_idx, {"type": "speech", "path": str(out_path), "speaker_id": planned.speaker_id, "text_chunk": planned.text_chunk}
            except Exception as e:
                return planned.segment_idx, {"type": "error", "message": f"Error generating speech: {e}", "text_chunk": planned.text_chunk}
    
  • Schedule with asyncio.create_task(run_one(p)) and collect tasks.

  3. Await and collate (a condensed sketch follows this section)
  • results_map = {}; for each completed task, set results_map[idx] = payload.
  • Merge: start with all previously final (silence/reuse/error) entries placed by segment_idx, then fill speech results by segment_idx into a single segment_results list sorted ascending by index.
  • Keep processing_log entries for each planned segment (queued, started, finished, errors).
  4. Return value unchanged
  • Return {"log": ..., "segment_files": segment_results, "temp_dir": str(dialog_temp_dir)}. This maintains router and concatenator behavior.
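
Putting steps 2 and 3 together, a condensed sketch of the scheduling and collation logic (run_one is the wrapper shown above; planned_speech_segments and final_segments are illustrative names for the outputs of the planning pass, with asyncio and config imported at module level):

    sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)

    # Schedule every planned speech chunk; silence/reuse entries are already final.
    tasks = [asyncio.create_task(run_one(p)) for p in planned_speech_segments]

    # Collate by segment_idx so ordering matches the sequential implementation.
    results_map = dict(final_segments)  # segment_idx -> payload for silence/reuse entries
    for segment_idx, payload in await asyncio.gather(*tasks):
        results_map[segment_idx] = payload

    segment_results = [results_map[i] for i in sorted(results_map)]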

B. Offload audio writes

  • File: backend/app/services/tts_service.py
  • Function: TTSService.generate_speech()
  1. After obtaining the wav tensor, replace:
# torchaudio.save(str(output_file_path), wav, self.model.sr)

with:

await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
  • Keep the rest of the cleanup logic (deleting wav, gc.collect(), cache emptying) unchanged.
  2. Optional (CUDA-only AMP)
  • If CUDA is used and config.TTS_ENABLE_AMP_ON_CUDA is True, wrap forward with AMP:
with torch.cuda.amp.autocast(dtype=torch.float16):
    wav = self.model.generate(...)
  • Leave MPS/CPU code path as-is.
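
Taken together, the affected portion of generate_speech() would look roughly like this (self.device and the elided generate arguments are placeholders for whatever the service already uses; only the AMP wrapper and the threaded save are new):

    # Forward pass; wrap in CUDA AMP only when enabled and actually running on CUDA.
    if self.device == "cuda" and config.TTS_ENABLE_AMP_ON_CUDA:
        with torch.cuda.amp.autocast(dtype=torch.float16):
            wav = self.model.generate(...)  # existing call, arguments unchanged
    else:
        wav = self.model.generate(...)      # MPS/CPU path unchanged

    # Save on a worker thread so the event loop (and the next forward) is not blocked.
    await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)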

Error Handling & Ordering

  • Every planned segment owns a unique segment_idx.
  • On failure, insert an error record at that index; downstream concatenation already skips missing/nonexistent paths.
  • Preserve exact output order expected by routers/dialog.py::process_dialog_flow().

Performance Expectations

  • GPU utilization should increase from ~50% to 75–90%, depending on dialog size and line lengths.
  • Wall-clock reduction is workload-dependent; target 1.5–2.5x on multi-line dialogs.

Metrics & Instrumentation

  • Add timestamped log entries per segment: planned→queued→started→saved.
  • Log effective concurrency (max in-flight), and cumulative GPU time if available.
  • Optionally add a simple timing summary at end of process_dialog().
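
A lightweight way to capture the per-segment timestamps and effective concurrency, assuming the run_one wrapper from section A lives inside process_dialog() (the in_flight/max_in_flight counters and log wording are illustrative):

    import time

    in_flight = 0
    max_in_flight = 0  # effective concurrency observed for this request

    async def run_one(planned):
        nonlocal in_flight, max_in_flight
        processing_log.append(f"segment {planned.segment_idx} queued at {time.monotonic():.3f}")
        async with sem:
            in_flight += 1
            max_in_flight = max(max_in_flight, in_flight)
            started = time.monotonic()
            processing_log.append(f"segment {planned.segment_idx} started at {started:.3f}")
            try:
                ...  # generate_speech call and result handling as in section A
            finally:
                in_flight -= 1
                processing_log.append(
                    f"segment {planned.segment_idx} finished after {time.monotonic() - started:.2f}s"
                )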

Testing Plan

  1. Unit-ish
  • Small dialog (3 speech lines, 1 silence). Ensure ordering is stable and files exist.
  • Introduce an invalid speaker to verify error propagation doesn't break the rest.
  2. Integration
  • POST /api/dialog/generate with 20–50 mixed-length lines and a couple of silences.
  • Validate: response OK, concatenated file exists, zip contains all generated speech segments, order preserved.
  • Compare runtime vs. sequential baseline (before/after).
  3. Stress/limits
  • Long lines split into many chunks; verify no OOM with TTS_MAX_CONCURRENCY=3.
  • Try TTS_MAX_CONCURRENCY=1 to simulate sequential behavior and compare metrics against the concurrent run (a small ordering-comparison helper is sketched below).
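
A small helper for the ordering comparison, assuming both runs return the process_dialog() result shape described in section A (the helper name is illustrative):

    def assert_order_preserved(sequential_result: dict, concurrent_result: dict) -> None:
        # Compare segment_files from a TTS_MAX_CONCURRENCY=1 run against a concurrent run;
        # paths are ignored so the check does not depend on temp directories, but types and
        # text chunks must appear in exactly the same order.
        seq = [(s["type"], s.get("text_chunk")) for s in sequential_result["segment_files"]]
        con = [(s["type"], s.get("text_chunk")) for s in concurrent_result["segment_files"]]
        assert seq == con, "segment ordering changed under concurrency"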

Rollout & Config Defaults

  • Default TTS_MAX_CONCURRENCY=3.
  • Expose via environment variable; no client changes needed.
  • If instability is observed, set TTS_MAX_CONCURRENCY=1 to quickly revert to sequential behavior.

Risks & Mitigations

  • OOM under high concurrency → Mitigate with a low default, easy rollback, and the chunking already in place.
  • Disk I/O saturation → Offload to threads; if disk is a bottleneck, decrease concurrency.
  • Model thread safety → We call model.generate concurrently only up to the semaphore cap; if the underlying library is not thread-safe for concurrent forward passes, serialize the forwards while still overlapping them with file I/O; early logs will reveal this.

Follow-up (Out of Scope for this change)

  • Dynamic batching queue inside TTSService for further GPU efficiency.
  • CUDA AMP enablement and profiling.
  • Per-speaker sub-queues if batching requires same-speaker inputs.

Acceptance Criteria

  • TTS_MAX_CONCURRENCY is configurable; default=3.
  • File writes occur via asyncio.to_thread.
  • Order of segment_files unchanged relative to sequential output.
  • End-to-end works for both small and large dialogs; error cases logged.
  • Observed GPU utilization and runtime improve on representative dialog.