diff --git a/.note/concurrency_plan.md b/.note/concurrency_plan.md
new file mode 100644
index 0000000..a7e9f77
--- /dev/null
+++ b/.note/concurrency_plan.md
@@ -0,0 +1,188 @@
+# Chatterbox TTS Backend: Bounded Concurrency + File I/O Offload Plan
+
+Date: 2025-08-14
+Owner: Backend
+Status: Proposed (ready to implement)
+
+## Goals
+
+- Increase GPU utilization and reduce wall-clock time for dialog generation.
+- Keep the model lifecycle stable (leveraging the current `ModelManager`).
+- Minimal-risk changes: no API shape changes for clients.
+
+## Scope
+
+- Implement bounded concurrency for per-line speech chunk generation within a single dialog request.
+- Offload audio file writes to threads to overlap GPU compute and disk I/O.
+- Add configuration knobs to tune concurrency.
+
+## Current State (References)
+
+- `backend/app/services/dialog_processor_service.py`
+  - `DialogProcessorService.process_dialog()` iterates items and awaits `tts_service.generate_speech(...)` sequentially (lines ~171–201).
+- `backend/app/services/tts_service.py`
+  - `TTSService.generate_speech()` runs the TTS forward pass and calls `torchaudio.save(...)` on the event loop thread (blocking).
+- `backend/app/services/model_manager.py`
+  - `ModelManager.using()` tracks active work; prevents idle eviction during requests.
+- `backend/app/routers/dialog.py`
+  - `process_dialog_flow()` expects ordered `segment_files` and then concatenates; the order must remain stable.
+
+## Design Overview
+
+1) Bounded concurrency at the dialog level
+
+- Plan all output segments with a stable `segment_idx` (including speech chunks, silence, and reused audio).
+- For speech chunks, schedule concurrent async tasks gated by a global semaphore sized by the config value `TTS_MAX_CONCURRENCY` (start at 3–4).
+- Await all tasks and collate results by `segment_idx` to preserve order.
+
+2) File I/O offload
+
+- Replace the direct `torchaudio.save(...)` call with `await asyncio.to_thread(torchaudio.save, ...)` in `TTSService.generate_speech()`.
+- This lets the next GPU forward pass start while previous file writes happen on worker threads.
+
+## Configuration
+
+Add to `backend/app/config.py`:
+
+- `TTS_MAX_CONCURRENCY: int` (default: `int(os.getenv("TTS_MAX_CONCURRENCY", "3"))`).
+- Optional (future): `TTS_ENABLE_AMP_ON_CUDA: bool = True` to allow mixed precision on CUDA only.
+
+## Implementation Steps
+
+### A. Dialog-level concurrency
+
+- File: `backend/app/services/dialog_processor_service.py`
+- Function: `DialogProcessorService.process_dialog()`
+
+1. Planning pass to assign indices
+
+- Iterate `dialog_items` and build a list of `planned_segments` entries:
+  - For silence or reuse: immediately append a final result with its assigned `segment_idx` and continue.
+  - For speech: split into `text_chunks`; for each chunk create a planned entry: `{ segment_idx, type: 'speech', speaker_id, text_chunk, abs_speaker_sample_path, tts_params }`.
+  - Increment `segment_idx` for every planned segment (speech chunk or silence/reuse) to preserve final order.
+
+2. Concurrency setup
+
+- Create `sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)`.
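+
+  The bounding pattern in isolation, as a runnable sketch (names illustrative; `3` stands in for the configured value):
+
+  ```python
+  import asyncio
+
+  async def main():
+      sem = asyncio.Semaphore(3)  # stands in for config.TTS_MAX_CONCURRENCY
+
+      async def bounded(i: int) -> int:
+          async with sem:  # at most 3 tasks inside this block at once
+              await asyncio.sleep(0.1)  # stands in for one TTS forward + save
+              return i
+
+      results = await asyncio.gather(*(bounded(i) for i in range(10)))
+      assert results == list(range(10))  # gather returns results in submission order
+
+  asyncio.run(main())
+  ```
+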
+- For each planned speech segment, create a task with an inner wrapper:
+
+  ```python
+  async def run_one(planned):
+      async with sem:
+          try:
+              out_path = await self.tts_service.generate_speech(
+                  text=planned.text_chunk,
+                  speaker_sample_path=planned.abs_speaker_sample_path,
+                  output_filename_base=planned.filename_base,
+                  output_dir=dialog_temp_dir,
+                  exaggeration=planned.exaggeration,
+                  cfg_weight=planned.cfg_weight,
+                  temperature=planned.temperature,
+              )
+              return planned.segment_idx, {"type": "speech", "path": str(out_path), "speaker_id": planned.speaker_id, "text_chunk": planned.text_chunk}
+          except Exception as e:
+              return planned.segment_idx, {"type": "error", "message": f"Error generating speech: {e}", "text_chunk": planned.text_chunk}
+  ```
+
+- Schedule each wrapper with `asyncio.create_task(run_one(p))` and collect the tasks.
+
+3. Await and collate
+
+- `results_map = {}`; for each completed task, set `results_map[idx] = payload`.
+- Merge: start with the already-final entries (silence/reuse/error) placed by `segment_idx`, then fill in the speech results by `segment_idx`, producing a single `segment_results` list sorted ascending by index.
+- Keep `processing_log` entries for each planned segment (queued, started, finished, errors).
+
+4. Return value unchanged
+
+- Return `{"log": ..., "segment_files": segment_results, "temp_dir": str(dialog_temp_dir)}`. This keeps router and concatenator behavior unchanged.
+
+### B. Offload audio writes
+
+- File: `backend/app/services/tts_service.py`
+- Function: `TTSService.generate_speech()`
+
+1. After obtaining the `wav` tensor, replace:
+
+```python
+torchaudio.save(str(output_file_path), wav, self.model.sr)
+```
+
+with:
+
+```python
+await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
+```
+
+- Keep the rest of the cleanup logic (delete `wav`, `gc.collect()`, cache emptying) unchanged.
+
+2. Optional (CUDA-only AMP)
+
+- If CUDA is in use and `config.TTS_ENABLE_AMP_ON_CUDA` is True, wrap the forward pass with AMP:
+
+```python
+with torch.cuda.amp.autocast(dtype=torch.float16):
+    wav = self.model.generate(...)
+```
+
+- Leave the MPS/CPU code paths as-is.
+
+## Error Handling & Ordering
+
+- Every planned segment owns a unique `segment_idx`.
+- On failure, insert an error record at that index; downstream concatenation already skips missing/nonexistent paths.
+- Preserve the exact output order expected by `routers/dialog.py::process_dialog_flow()`.
+
+## Performance Expectations
+
+- GPU utilization should increase from ~50% to 75–90%, depending on dialog size and line lengths.
+- Wall-clock reduction is workload-dependent; target 1.5–2.5x on multi-line dialogs.
+
+## Metrics & Instrumentation
+
+- Add timestamped log entries per segment: planned → queued → started → saved.
+- Log the effective concurrency (max in-flight) and cumulative GPU time if available.
+- Optionally add a simple timing summary at the end of `process_dialog()`.
+
+## Testing Plan
+
+1. Unit-ish
+
+- Small dialog (3 speech lines, 1 silence). Ensure ordering is stable and files exist (see the sketch below).
+- Introduce an invalid speaker to verify that error propagation doesn't break the rest.
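+
+For the ordering check, a minimal sketch of a test helper (the `process_dialog` signature and the `service` wiring are assumptions, not final test code):
+
+```python
+from typing import Any, Dict, List
+
+async def assert_order_preserved(service, dialog_items: List[Dict[str, Any]]) -> None:
+    # `service` is assumed to be a wired-up DialogProcessorService instance.
+    result = await service.process_dialog(dialog_items, output_base_name="order_test")
+    paths = [seg.get("path", "") for seg in result["segment_files"]]
+    # Speech/reuse filenames embed the planned index as "_seg{N}_"; silence entries have no path.
+    indices = [int(p.split("_seg")[1].split("_")[0]) for p in paths if "_seg" in p]
+    assert indices == sorted(indices), f"segment order changed: {indices}"
+```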
+
+2. Integration
+
+- POST `/api/dialog/generate` with 20–50 mixed-length lines and a couple of silences.
+- Validate: response OK, concatenated file exists, zip contains all generated speech segments, order preserved.
+- Compare runtime against the sequential baseline (before/after).
+
+3. Stress/limits
+
+- Long lines split into many chunks; verify no OOM with `TTS_MAX_CONCURRENCY=3`.
+- Try `TTS_MAX_CONCURRENCY=1` to simulate sequential execution; compare metrics.
+
+## Rollout & Config Defaults
+
+- Default `TTS_MAX_CONCURRENCY=3`.
+- Expose via environment variable; no client changes needed.
+- If instability is observed, set `TTS_MAX_CONCURRENCY=1` to revert to sequential behavior quickly.
+
+## Risks & Mitigations
+
+- OOM under high concurrency → mitigate with a low default, easy rollback, and the chunking already in place.
+- Disk I/O saturation → offload to threads; if disk is the bottleneck, decrease concurrency.
+- Model thread safety → `model.generate` is called concurrently only up to the semaphore cap; if the underlying library is not thread-safe for forward passes, consider serializing forwards while still overlapping with file I/O. Early logs will reveal this.
+
+## Follow-up (Out of Scope for this change)
+
+- Dynamic batching queue inside `TTSService` for further GPU efficiency.
+- CUDA AMP enablement and profiling.
+- Per-speaker sub-queues if batching requires same-speaker inputs.
+
+## Acceptance Criteria
+
+- `TTS_MAX_CONCURRENCY` is configurable; the default is 3.
+- File writes occur via `asyncio.to_thread`.
+- The order of `segment_files` is unchanged relative to sequential output.
+- End-to-end generation works for both small and large dialogs; error cases are logged.
+- Observed GPU utilization and runtime improve on a representative dialog.
diff --git a/backend/app/config.py b/backend/app/config.py
index 60e2d67..278fdbc 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -66,6 +66,10 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
 
 # Device configuration
 DEVICE = os.getenv("DEVICE", "auto")
+
+# Concurrency configuration
+# Max number of concurrent TTS generation tasks per dialog request
+TTS_MAX_CONCURRENCY = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))
 
 # Model idle eviction configuration
 # Enable/disable idle-based model eviction
diff --git a/backend/app/services/dialog_processor_service.py b/backend/app/services/dialog_processor_service.py
index df69aae..4dc7af2 100644
--- a/backend/app/services/dialog_processor_service.py
+++ b/backend/app/services/dialog_processor_service.py
@@ -1,6 +1,8 @@
 from pathlib import Path
 from typing import List, Dict, Any, Union
 import re
+import asyncio
+from datetime import datetime
 
 from .tts_service import TTSService
 from .speaker_service import SpeakerManagementService
@@ -92,24 +94,72 @@ class DialogProcessorService:
         import shutil
 
         segment_idx = 0
+        tasks = []
+        results_map: Dict[int, Dict[str, Any]] = {}
+        sem = asyncio.Semaphore(getattr(config, "TTS_MAX_CONCURRENCY", 2))
+
+        async def run_one(planned: Dict[str, Any]):
+            async with sem:
+                text_chunk = planned["text_chunk"]
+                speaker_id = planned["speaker_id"]
+                abs_speaker_sample_path = planned["abs_speaker_sample_path"]
+                filename_base = planned["filename_base"]
+                params = planned["params"]
+                seg_idx = planned["segment_idx"]
+                start_ts = datetime.now()
+                start_line = (
+                    f"[{start_ts.isoformat(timespec='seconds')}] [TTS-TASK] START seg_idx={seg_idx} "
+                    f"speaker={speaker_id} chunk_len={len(text_chunk)} base={filename_base}"
+                )
+                try:
+                    out_path = await self.tts_service.generate_speech(
+                        text=text_chunk,
+                        speaker_id=speaker_id,
+                        speaker_sample_path=str(abs_speaker_sample_path),
+                        output_filename_base=filename_base,
+                        output_dir=dialog_temp_dir,
+                        exaggeration=params.get('exaggeration', 0.5),
+                        cfg_weight=params.get('cfg_weight', 0.5),
+                        temperature=params.get('temperature', 0.8),
+                    )
+                    end_ts = datetime.now()
+                    duration = (end_ts - start_ts).total_seconds()
+                    end_line = (
+                        f"[{end_ts.isoformat(timespec='seconds')}] [TTS-TASK] END seg_idx={seg_idx} "
+                        f"dur={duration:.2f}s -> {out_path}"
+                    )
+                    return seg_idx, {
+                        "type": "speech",
+                        "path": str(out_path),
+                        "speaker_id": speaker_id,
+                        "text_chunk": text_chunk,
+                    }, start_line + "\n" + f"Successfully generated segment: {out_path}" + "\n" + end_line
+                except Exception as e:
+                    end_ts = datetime.now()
+                    err_line = (
+                        f"[{end_ts.isoformat(timespec='seconds')}] [TTS-TASK] ERROR seg_idx={seg_idx} "
+                        f"speaker={speaker_id} err={repr(e)}"
+                    )
+                    return seg_idx, {
+                        "type": "error",
+                        "message": f"Error generating speech for chunk '{text_chunk[:50]}...': {repr(e)}",
+                        "text_chunk": text_chunk,
+                    }, err_line
+
         for i, item in enumerate(dialog_items):
             item_type = item.get("type")
             processing_log.append(f"Processing item {i+1}: type='{item_type}'")
 
-            # --- Universal: Handle reuse of existing audio for both speech and silence ---
+            # --- Handle reuse of existing audio ---
             use_existing_audio = item.get("use_existing_audio", False)
             audio_url = item.get("audio_url")
             if use_existing_audio and audio_url:
-                # Determine source path (handle both absolute and relative)
-                # Map web URL to actual file location in tts_generated_dialogs
                 if audio_url.startswith("/generated_audio/"):
                     src_audio_path = config.DIALOG_OUTPUT_DIR / audio_url[len("/generated_audio/"):]
                 else:
                     src_audio_path = Path(audio_url)
                     if not src_audio_path.is_absolute():
-                        # Assume relative to the generated audio root dir
                        src_audio_path = config.DIALOG_OUTPUT_DIR / audio_url.lstrip("/\\")
-                # Now src_audio_path should point to the real file in tts_generated_dialogs
                 if src_audio_path.is_file():
                     segment_filename = f"{output_base_name}_seg{segment_idx}_reused.wav"
                     dest_path = (self.temp_audio_dir / output_base_name / segment_filename)
@@ -123,22 +173,18 @@ class DialogProcessorService:
                         processing_log.append(f"[REUSE] Destination audio file was not created: {dest_path}")
                     else:
                         processing_log.append(f"[REUSE] Destination audio file created: {dest_path}, size={dest_path.stat().st_size} bytes")
-                    # Only include 'type' and 'path' so the concatenator always includes this segment
-                    segment_results.append({
-                        "type": item_type,
-                        "path": str(dest_path)
-                    })
+                    results_map[segment_idx] = {"type": item_type, "path": str(dest_path)}
                     processing_log.append(f"Reused existing audio for item {i+1}: copied from {src_audio_path} to {dest_path}")
                 except Exception as e:
                     error_message = f"Failed to copy reused audio for item {i+1}: {e}"
                     processing_log.append(error_message)
-                    segment_results.append({"type": "error", "message": error_message})
+                    results_map[segment_idx] = {"type": "error", "message": error_message}
                     segment_idx += 1
                     continue
                 else:
                     error_message = f"Audio file for reuse not found at {src_audio_path} for item {i+1}."
                     processing_log.append(error_message)
-                    segment_results.append({"type": "error", "message": error_message})
+                    results_map[segment_idx] = {"type": "error", "message": error_message}
                     segment_idx += 1
                     continue
 
@@ -147,70 +193,81 @@ class DialogProcessorService:
             text = item.get("text")
             if not speaker_id or not text:
                 processing_log.append(f"Skipping speech item {i+1} due to missing speaker_id or text.")
-                segment_results.append({"type": "error", "message": "Missing speaker_id or text"})
+                results_map[segment_idx] = {"type": "error", "message": "Missing speaker_id or text"}
+                segment_idx += 1
                 continue
 
-            # Validate speaker_id and get speaker_sample_path
             speaker_info = self.speaker_service.get_speaker_by_id(speaker_id)
             if not speaker_info:
                 processing_log.append(f"Speaker ID '{speaker_id}' not found. Skipping item {i+1}.")
-                segment_results.append({"type": "error", "message": f"Speaker ID '{speaker_id}' not found"})
+                results_map[segment_idx] = {"type": "error", "message": f"Speaker ID '{speaker_id}' not found"}
+                segment_idx += 1
                 continue
             if not speaker_info.sample_path:
                 processing_log.append(f"Speaker ID '{speaker_id}' has no sample path defined. Skipping item {i+1}.")
-                segment_results.append({"type": "error", "message": f"Speaker ID '{speaker_id}' has no sample path defined"})
+                results_map[segment_idx] = {"type": "error", "message": f"Speaker ID '{speaker_id}' has no sample path defined"}
+                segment_idx += 1
                 continue
 
-            # speaker_info.sample_path is relative to config.SPEAKER_DATA_BASE_DIR
             abs_speaker_sample_path = config.SPEAKER_DATA_BASE_DIR / speaker_info.sample_path
             if not abs_speaker_sample_path.is_file():
                 processing_log.append(f"Speaker sample file not found or is not a file at '{abs_speaker_sample_path}' for speaker ID '{speaker_id}'. Skipping item {i+1}.")
-                segment_results.append({"type": "error", "message": f"Speaker sample not a file or not found: {abs_speaker_sample_path}"})
+                results_map[segment_idx] = {"type": "error", "message": f"Speaker sample not a file or not found: {abs_speaker_sample_path}"}
+                segment_idx += 1
                 continue
 
             text_chunks = self._split_text(text)
             processing_log.append(f"Split text for speaker '{speaker_id}' into {len(text_chunks)} chunk(s).")
 
             for chunk_idx, text_chunk in enumerate(text_chunks):
-                segment_filename_base = f"{output_base_name}_seg{segment_idx}_spk{speaker_id}_chunk{chunk_idx}"
-                processing_log.append(f"Generating speech for chunk: '{text_chunk[:50]}...' using speaker '{speaker_id}'")
-
-                try:
-                    segment_output_path = await self.tts_service.generate_speech(
-                        text=text_chunk,
-                        speaker_id=speaker_id, # For metadata, actual sample path is used by TTS
-                        speaker_sample_path=str(abs_speaker_sample_path),
-                        output_filename_base=segment_filename_base,
-                        output_dir=dialog_temp_dir, # Save to the dialog's temp dir
-                        exaggeration=item.get('exaggeration', 0.5), # Default from Gradio, Pydantic model should provide this
-                        cfg_weight=item.get('cfg_weight', 0.5), # Default from Gradio, Pydantic model should provide this
-                        temperature=item.get('temperature', 0.8) # Default from Gradio, Pydantic model should provide this
-                    )
-                    segment_results.append({
-                        "type": "speech",
-                        "path": str(segment_output_path),
-                        "speaker_id": speaker_id,
-                        "text_chunk": text_chunk
-                    })
-                    processing_log.append(f"Successfully generated segment: {segment_output_path}")
-                except Exception as e:
-                    error_message = f"Error generating speech for chunk '{text_chunk[:50]}...': {repr(e)}"
-                    processing_log.append(error_message)
-                    segment_results.append({"type": "error", "message": error_message, "text_chunk": text_chunk})
+                filename_base = f"{output_base_name}_seg{segment_idx}_spk{speaker_id}_chunk{chunk_idx}"
+                processing_log.append(f"Queueing TTS for chunk: '{text_chunk[:50]}...' using speaker '{speaker_id}'")
+                planned = {
+                    "segment_idx": segment_idx,
+                    "speaker_id": speaker_id,
+                    "text_chunk": text_chunk,
+                    "abs_speaker_sample_path": abs_speaker_sample_path,
+                    "filename_base": filename_base,
+                    "params": {
+                        'exaggeration': item.get('exaggeration', 0.5),
+                        'cfg_weight': item.get('cfg_weight', 0.5),
+                        'temperature': item.get('temperature', 0.8),
+                    },
+                }
+                tasks.append(asyncio.create_task(run_one(planned)))
                 segment_idx += 1
-
+
             elif item_type == "silence":
                 duration = item.get("duration")
                 if duration is None or duration < 0:
                     processing_log.append(f"Skipping silence item {i+1} due to invalid duration.")
-                    segment_results.append({"type": "error", "message": "Invalid duration for silence"})
+                    results_map[segment_idx] = {"type": "error", "message": "Invalid duration for silence"}
+                    segment_idx += 1
                     continue
-                segment_results.append({"type": "silence", "duration": float(duration)})
+                results_map[segment_idx] = {"type": "silence", "duration": float(duration)}
                 processing_log.append(f"Added silence of {duration}s.")
-
+                segment_idx += 1
+
             else:
                 processing_log.append(f"Unknown item type '{item_type}' at item {i+1}. Skipping.")
-                segment_results.append({"type": "error", "message": f"Unknown item type: {item_type}"})
+                results_map[segment_idx] = {"type": "error", "message": f"Unknown item type: {item_type}"}
+                segment_idx += 1
+
+        # Await all TTS tasks and merge results
+        if tasks:
+            processing_log.append(
+                f"Dispatching {len(tasks)} TTS task(s) with concurrency limit "
+                f"{getattr(config, 'TTS_MAX_CONCURRENCY', 2)}"
+            )
+            completed = await asyncio.gather(*tasks, return_exceptions=False)
+            for idx, payload, maybe_log in completed:
+                results_map[idx] = payload
+                if maybe_log:
+                    processing_log.append(maybe_log)
+
+        # Build ordered list
+        for idx in sorted(results_map.keys()):
+            segment_results.append(results_map[idx])
 
         # Log the full segment_results list for debugging
         processing_log.append("[DEBUG] Final segment_results list:")
@@ -220,7 +277,7 @@ class DialogProcessorService:
         return {
             "log": "\n".join(processing_log),
            "segment_files": segment_results,
-            "temp_dir": str(dialog_temp_dir) # For cleanup or zipping later
+            "temp_dir": str(dialog_temp_dir)
         }
 
 if __name__ == "__main__":
Skipping.") - segment_results.append({"type": "error", "message": f"Unknown item type: {item_type}"}) + results_map[segment_idx] = {"type": "error", "message": f"Unknown item type: {item_type}"} + segment_idx += 1 + + # Await all TTS tasks and merge results + if tasks: + processing_log.append( + f"Dispatching {len(tasks)} TTS task(s) with concurrency limit " + f"{getattr(config, 'TTS_MAX_CONCURRENCY', 2)}" + ) + completed = await asyncio.gather(*tasks, return_exceptions=False) + for idx, payload, maybe_log in completed: + results_map[idx] = payload + if maybe_log: + processing_log.append(maybe_log) + + # Build ordered list + for idx in sorted(results_map.keys()): + segment_results.append(results_map[idx]) # Log the full segment_results list for debugging processing_log.append("[DEBUG] Final segment_results list:") @@ -220,7 +277,7 @@ class DialogProcessorService: return { "log": "\n".join(processing_log), "segment_files": segment_results, - "temp_dir": str(dialog_temp_dir) # For cleanup or zipping later + "temp_dir": str(dialog_temp_dir) } if __name__ == "__main__": diff --git a/backend/app/services/tts_service.py b/backend/app/services/tts_service.py index 2b3f05d..8840a9a 100644 --- a/backend/app/services/tts_service.py +++ b/backend/app/services/tts_service.py @@ -1,11 +1,14 @@ import torch import torchaudio +import asyncio from typing import Optional from chatterbox.tts import ChatterboxTTS from pathlib import Path import gc # Garbage collector for memory management import os from contextlib import contextmanager +from datetime import datetime +import time # Import configuration try: @@ -114,42 +117,52 @@ class TTSService: # output_filename_base from DialogProcessorService is expected to be comprehensive (e.g., includes speaker_id, segment info) output_file_path = target_output_dir / f"{output_filename_base}.wav" - print(f"Generating audio for text: \"{text[:50]}...\" with speaker sample: {speaker_sample_path}") - wav = None + start_ts = datetime.now() + print(f"[{start_ts.isoformat(timespec='seconds')}] [TTS] START generate+save base={output_filename_base} len={len(text)} sample={speaker_sample_path}") try: - with torch.no_grad(): # Important for inference - wav = self.model.generate( - text=text, - audio_prompt_path=str(speaker_sample_p), # Must be a string path - exaggeration=exaggeration, - cfg_weight=cfg_weight, - temperature=temperature, - ) - - torchaudio.save(str(output_file_path), wav, self.model.sr) - print(f"Audio saved to: {output_file_path}") - return output_file_path - except Exception as e: - print(f"Error during TTS generation or saving: {e}") - raise - finally: - # Explicitly delete the wav tensor to free memory - if wav is not None: - del wav - - # Force garbage collection and cache cleanup - gc.collect() - if self.device == "cuda": - torch.cuda.empty_cache() - elif self.device == "mps": - if hasattr(torch.mps, "empty_cache"): - torch.mps.empty_cache() - - # Unload the model if requested + def _gen_and_save() -> Path: + t0 = time.perf_counter() + wav = None + try: + with torch.no_grad(): # Important for inference + wav = self.model.generate( + text=text, + audio_prompt_path=str(speaker_sample_p), # Must be a string path + exaggeration=exaggeration, + cfg_weight=cfg_weight, + temperature=temperature, + ) + + # Save the audio synchronously in the same thread + torchaudio.save(str(output_file_path), wav, self.model.sr) + t1 = time.perf_counter() + print(f"[TTS-THREAD] Saved {output_file_path.name} in {t1 - t0:.2f}s") + return output_file_path + finally: + # Cleanup in 
the same thread that created the tensor + if wav is not None: + del wav + gc.collect() + if self.device == "cuda": + torch.cuda.empty_cache() + elif self.device == "mps": + if hasattr(torch.mps, "empty_cache"): + torch.mps.empty_cache() + + out_path = await asyncio.to_thread(_gen_and_save) + end_ts = datetime.now() + print(f"[{end_ts.isoformat(timespec='seconds')}] [TTS] END generate+save base={output_filename_base} dur={(end_ts - start_ts).total_seconds():.2f}s -> {out_path}") + + # Optionally unload model after generation if unload_after: print("Unloading TTS model after generation...") self.unload_model() + return out_path + except Exception as e: + print(f"Error during TTS generation or saving: {e}") + raise + # Example usage (for testing, not part of the service itself) if __name__ == "__main__": async def main_test(): diff --git a/frontend/index.html b/frontend/index.html index 065b902..a44da10 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -11,6 +11,29 @@