Added backend concurrency and frontend paste feature.
This commit is contained in:
parent b28a9bcf58
commit 75a2a37252

@ -0,0 +1,188 @@

# Chatterbox TTS Backend: Bounded Concurrency + File I/O Offload Plan

Date: 2025-08-14

Owner: Backend

Status: Proposed (ready to implement)

## Goals

- Increase GPU utilization and reduce wall-clock time for dialog generation.
- Keep the model lifecycle stable (leveraging the current `ModelManager`).
- Minimal-risk changes: no API shape changes for clients.

## Scope

- Implement bounded concurrency for per-line speech chunk generation within a single dialog request.
- Offload audio file writes to threads so that GPU compute and disk I/O overlap.
- Add configuration knobs to tune concurrency.

## Current State (References)

- `backend/app/services/dialog_processor_service.py`
  - `DialogProcessorService.process_dialog()` iterates items and awaits `tts_service.generate_speech(...)` sequentially (lines ~171–201).
- `backend/app/services/tts_service.py`
  - `TTSService.generate_speech()` runs the TTS forward pass and calls `torchaudio.save(...)` on the event loop thread (blocking).
- `backend/app/services/model_manager.py`
  - `ModelManager.using()` tracks active work and prevents idle eviction during requests.
- `backend/app/routers/dialog.py`
  - `process_dialog_flow()` expects an ordered `segment_files` list and then concatenates, so the order must stay stable.

## Design Overview

1) Bounded concurrency at the dialog level

- Plan all output segments with a stable `segment_idx` (including speech chunks, silence, and reused audio).
- For speech chunks, schedule concurrent async tasks gated by a global semaphore sized from the config value `TTS_MAX_CONCURRENCY` (start at 3–4).
- Await all tasks and collate results by `segment_idx` to preserve order, as sketched below.
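
A compact sketch of the intended pattern (illustrative only; `generate_segment` and `planned_speech_segments` are hypothetical stand-ins for the planning pass detailed under Implementation Steps, and `config` is the app config module):

```python
import asyncio

sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)  # global cap on in-flight TTS calls

async def run_one(planned):
    async with sem:  # at most TTS_MAX_CONCURRENCY forward passes at once
        payload = await generate_segment(planned)  # hypothetical helper wrapping TTSService
        return planned["segment_idx"], payload

async def run_all(planned_speech_segments):
    tasks = [asyncio.create_task(run_one(p)) for p in planned_speech_segments]
    results = dict(await asyncio.gather(*tasks))      # segment_idx -> payload
    return [results[idx] for idx in sorted(results)]  # stable output order
```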

2) File I/O offload

- Replace the direct `torchaudio.save(...)` call with `await asyncio.to_thread(torchaudio.save, ...)` in `TTSService.generate_speech()`.
- This lets the next GPU forward pass start while previous file writes happen on worker threads.

## Configuration

Add to `backend/app/config.py` (a sketch of both settings follows this list):

- `TTS_MAX_CONCURRENCY: int` (default: `int(os.getenv("TTS_MAX_CONCURRENCY", "3"))`).
- Optional (future): `TTS_ENABLE_AMP_ON_CUDA: bool = True` to allow mixed precision on CUDA only.
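
A sketch of the corresponding additions (the `TTS_MAX_CONCURRENCY` line matches the config diff below; the `TTS_ENABLE_AMP_ON_CUDA` parsing is an assumption, since that knob is future work):

```python
import os

# Max number of concurrent TTS generation tasks per dialog request
TTS_MAX_CONCURRENCY = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))

# Optional (future): allow mixed precision on CUDA only (assumed env parsing)
TTS_ENABLE_AMP_ON_CUDA = os.getenv("TTS_ENABLE_AMP_ON_CUDA", "true").lower() == "true"
```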

## Implementation Steps

### A. Dialog-level concurrency

- File: `backend/app/services/dialog_processor_service.py`
- Function: `DialogProcessorService.process_dialog()`

1. Planning pass to assign indices

- Iterate `dialog_items` and build a list of `planned_segments` entries:
  - For silence or reuse: immediately record a final result under the assigned `segment_idx` and continue.
  - For speech: split into `text_chunks`; for each chunk create a planned entry: `{ segment_idx, type: 'speech', speaker_id, text_chunk, abs_speaker_sample_path, tts_params }`.
- Increment `segment_idx` for every planned segment (speech chunk or silence/reuse) to preserve the final order. A sketch of this pass follows.
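
A minimal sketch of the planning pass (plain dicts are assumed for planned entries, and reuse handling is abbreviated to keep the example short):

```python
planned_segments = []  # speech chunks awaiting concurrent generation
results_map = {}       # segment_idx -> already-final results (silence, reuse, errors)
segment_idx = 0

for item in dialog_items:
    if item["type"] == "silence":
        # Finalize immediately; no TTS work is needed for this index.
        results_map[segment_idx] = {"type": "silence", "duration": float(item["duration"])}
        segment_idx += 1
    elif item["type"] == "speech":
        for text_chunk in self._split_text(item["text"]):
            planned_segments.append({
                "segment_idx": segment_idx,
                "speaker_id": item["speaker_id"],
                "text_chunk": text_chunk,
            })
            segment_idx += 1
```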

2. Concurrency setup

- Create `sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)`.
- For each planned speech segment, create a task with an inner wrapper:

```python
async def run_one(planned):
    async with sem:
        try:
            out_path = await self.tts_service.generate_speech(
                text=planned.text_chunk,
                speaker_sample_path=planned.abs_speaker_sample_path,
                output_filename_base=planned.filename_base,
                output_dir=dialog_temp_dir,
                exaggeration=planned.exaggeration,
                cfg_weight=planned.cfg_weight,
                temperature=planned.temperature,
            )
            return planned.segment_idx, {"type": "speech", "path": str(out_path), "speaker_id": planned.speaker_id, "text_chunk": planned.text_chunk}
        except Exception as e:
            return planned.segment_idx, {"type": "error", "message": f"Error generating speech: {e}", "text_chunk": planned.text_chunk}
```

- Schedule with `asyncio.create_task(run_one(p))` and collect the tasks.

3. Await and collate

- `results_map = {}`; for each completed task, set `results_map[idx] = payload`.
- Merge: start with all previously finalized (silence/reuse/error) entries placed by `segment_idx`, then fill in the speech results by `segment_idx`, producing a single `segment_results` list sorted ascending by index, as sketched below.
- Keep `processing_log` entries for each planned segment (queued, started, finished, errors).
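
A sketch of the await-and-collate step (assuming `run_one` returns `(segment_idx, payload)` pairs as above):

```python
completed = await asyncio.gather(*tasks)  # run_one catches its own errors, so no return_exceptions
for idx, payload in completed:
    results_map[idx] = payload            # speech results join the silence/reuse/error entries

# One ordered list, regardless of task completion order.
segment_results = [results_map[idx] for idx in sorted(results_map)]
```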

4. Return value unchanged

- Return `{"log": ..., "segment_files": segment_results, "temp_dir": str(dialog_temp_dir)}`. This preserves the existing router and concatenator behavior.

### B. Offload audio writes

- File: `backend/app/services/tts_service.py`
- Function: `TTSService.generate_speech()`

1. After obtaining the `wav` tensor, replace:

```python
torchaudio.save(str(output_file_path), wav, self.model.sr)
```

with:

```python
await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
```

- Keep the rest of the cleanup logic (deleting `wav`, `gc.collect()`, cache emptying) unchanged.

2. Optional (CUDA-only AMP)

- If CUDA is in use and `config.TTS_ENABLE_AMP_ON_CUDA` is True, wrap the forward pass with AMP:

```python
with torch.cuda.amp.autocast(dtype=torch.float16):
    wav = self.model.generate(...)
```

- Leave the MPS/CPU code paths as-is.

## Error Handling & Ordering

- Every planned segment owns a unique `segment_idx`.
- On failure, insert an error record at that index; downstream concatenation already skips missing or nonexistent paths.
- Preserve the exact output order expected by `routers/dialog.py::process_dialog_flow()`.

## Performance Expectations

- GPU utilization should increase from ~50% to 75–90%, depending on dialog size and line lengths.
- Wall-clock reduction is workload-dependent; target 1.5–2.5x on multi-line dialogs.

## Metrics & Instrumentation

- Add timestamped log entries per segment: planned → queued → started → saved.
- Log the effective concurrency (max in-flight) and, if available, cumulative GPU time.
- Optionally add a simple timing summary at the end of `process_dialog()`, as sketched below.
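
For example, the summary could be as simple as this sketch (where `dispatch_started` is a hypothetical timestamp taken just before the tasks are scheduled):

```python
import time

dispatch_started = time.perf_counter()
# ... schedule tasks, await asyncio.gather(...) ...
elapsed = time.perf_counter() - dispatch_started
processing_log.append(
    f"[TIMING] {len(tasks)} TTS task(s) finished in {elapsed:.2f}s "
    f"(concurrency limit {config.TTS_MAX_CONCURRENCY})"
)
```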

## Testing Plan

1. Unit-ish

- Small dialog (3 speech lines, 1 silence). Ensure the ordering is stable and the files exist.
- Introduce an invalid speaker to verify that error propagation doesn't break the rest.

2. Integration

- POST `/api/dialog/generate` with 20–50 mixed-length lines and a couple of silences.
- Validate: response OK, concatenated file exists, zip contains all generated speech segments, order preserved.
- Compare runtime against the sequential baseline (before/after). A sample request is sketched below.
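
A sketch of such a request (field names are inferred from `process_dialog()` and the frontend's JSONL examples; the authoritative request schema is the Pydantic model used by `routers/dialog.py`, so treat the payload shape, host, and speaker IDs as assumptions):

```python
import requests

payload = {
    "output_base_name": "integration_test",  # assumed field name
    "dialog_items": [                        # assumed field name
        {"type": "speech", "speaker_id": "alice", "text": "Hello there!"},
        {"type": "silence", "duration": 0.5},
        {"type": "speech", "speaker_id": "bob", "text": "Hi!"},
    ],
}

resp = requests.post("http://localhost:8000/api/dialog/generate", json=payload, timeout=600)
resp.raise_for_status()
print(resp.status_code, resp.headers.get("content-type"))
```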

3. Stress/limits

- Long lines split into many chunks; verify there is no OOM with `TTS_MAX_CONCURRENCY=3`.
- Try `TTS_MAX_CONCURRENCY=1` to simulate sequential processing; compare metrics.

## Rollout & Config Defaults

- Default `TTS_MAX_CONCURRENCY=3`.
- Expose it via an environment variable; no client changes are needed.
- If instability is observed, set `TTS_MAX_CONCURRENCY=1` to revert to sequential behavior quickly.

## Risks & Mitigations

- OOM under high concurrency → mitigate with a low default, easy rollback, and the chunking already in place.
- Disk I/O saturation → offload to threads; if the disk is a bottleneck, decrease concurrency.
- Model thread safety → `model.generate` is called concurrently only up to the semaphore cap; if the underlying library is not thread-safe for forward passes, consider serializing the forwards while still overlapping them with file I/O (sketched below); early logs will reveal this.
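
If serialization becomes necessary, one low-risk variant keeps a single in-flight forward pass while still overlapping file writes (a sketch; `run_forward` and `save_wav` are hypothetical helpers):

```python
gpu_lock = asyncio.Lock()  # serializes forward passes only

async def run_one_serialized(planned):
    async with gpu_lock:
        # One forward pass at a time, in case the model is not thread-safe.
        wav = await asyncio.to_thread(run_forward, planned)  # hypothetical helper
    # The lock is released before the write, so the next forward pass
    # can start while this segment is still being saved to disk.
    return await asyncio.to_thread(save_wav, wav, planned)   # hypothetical helper
```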

## Follow-up (Out of Scope for this change)

- A dynamic batching queue inside `TTSService` for further GPU efficiency.
- CUDA AMP enablement and profiling.
- Per-speaker sub-queues if batching requires same-speaker inputs.

## Acceptance Criteria

- `TTS_MAX_CONCURRENCY` is configurable; default = 3.
- File writes occur via `asyncio.to_thread`.
- The order of `segment_files` is unchanged relative to the sequential output.
- End-to-end generation works for both small and large dialogs; error cases are logged.
- Observed GPU utilization and runtime improve on a representative dialog.

backend/app/config.py
@ -67,6 +67,10 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
 # Device configuration
 DEVICE = os.getenv("DEVICE", "auto")
 
+# Concurrency configuration
+# Max number of concurrent TTS generation tasks per dialog request
+TTS_MAX_CONCURRENCY = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))
+
 # Model idle eviction configuration
 # Enable/disable idle-based model eviction
 MODEL_EVICTION_ENABLED = os.getenv("MODEL_EVICTION_ENABLED", "true").lower() == "true"

backend/app/services/dialog_processor_service.py
@ -1,6 +1,8 @@
 from pathlib import Path
 from typing import List, Dict, Any, Union
 import re
+import asyncio
+from datetime import datetime
 
 from .tts_service import TTSService
 from .speaker_service import SpeakerManagementService
@ -92,24 +94,72 @@ class DialogProcessorService:
 
         import shutil
         segment_idx = 0
+        tasks = []
+        results_map: Dict[int, Dict[str, Any]] = {}
+        sem = asyncio.Semaphore(getattr(config, "TTS_MAX_CONCURRENCY", 2))
+
+        async def run_one(planned: Dict[str, Any]):
+            async with sem:
+                text_chunk = planned["text_chunk"]
+                speaker_id = planned["speaker_id"]
+                abs_speaker_sample_path = planned["abs_speaker_sample_path"]
+                filename_base = planned["filename_base"]
+                params = planned["params"]
+                seg_idx = planned["segment_idx"]
+                start_ts = datetime.now()
+                start_line = (
+                    f"[{start_ts.isoformat(timespec='seconds')}] [TTS-TASK] START seg_idx={seg_idx} "
+                    f"speaker={speaker_id} chunk_len={len(text_chunk)} base={filename_base}"
+                )
+                try:
+                    out_path = await self.tts_service.generate_speech(
+                        text=text_chunk,
+                        speaker_id=speaker_id,
+                        speaker_sample_path=str(abs_speaker_sample_path),
+                        output_filename_base=filename_base,
+                        output_dir=dialog_temp_dir,
+                        exaggeration=params.get('exaggeration', 0.5),
+                        cfg_weight=params.get('cfg_weight', 0.5),
+                        temperature=params.get('temperature', 0.8),
+                    )
+                    end_ts = datetime.now()
+                    duration = (end_ts - start_ts).total_seconds()
+                    end_line = (
+                        f"[{end_ts.isoformat(timespec='seconds')}] [TTS-TASK] END seg_idx={seg_idx} "
+                        f"dur={duration:.2f}s -> {out_path}"
+                    )
+                    return seg_idx, {
+                        "type": "speech",
+                        "path": str(out_path),
+                        "speaker_id": speaker_id,
+                        "text_chunk": text_chunk,
+                    }, start_line + "\n" + f"Successfully generated segment: {out_path}" + "\n" + end_line
+                except Exception as e:
+                    end_ts = datetime.now()
+                    err_line = (
+                        f"[{end_ts.isoformat(timespec='seconds')}] [TTS-TASK] ERROR seg_idx={seg_idx} "
+                        f"speaker={speaker_id} err={repr(e)}"
+                    )
+                    return seg_idx, {
+                        "type": "error",
+                        "message": f"Error generating speech for chunk '{text_chunk[:50]}...': {repr(e)}",
+                        "text_chunk": text_chunk,
+                    }, err_line
+
         for i, item in enumerate(dialog_items):
             item_type = item.get("type")
             processing_log.append(f"Processing item {i+1}: type='{item_type}'")
 
-            # --- Universal: Handle reuse of existing audio for both speech and silence ---
+            # --- Handle reuse of existing audio ---
             use_existing_audio = item.get("use_existing_audio", False)
             audio_url = item.get("audio_url")
             if use_existing_audio and audio_url:
-                # Determine source path (handle both absolute and relative)
-                # Map web URL to actual file location in tts_generated_dialogs
                 if audio_url.startswith("/generated_audio/"):
                     src_audio_path = config.DIALOG_OUTPUT_DIR / audio_url[len("/generated_audio/"):]
                 else:
                     src_audio_path = Path(audio_url)
                     if not src_audio_path.is_absolute():
-                        # Assume relative to the generated audio root dir
                        src_audio_path = config.DIALOG_OUTPUT_DIR / audio_url.lstrip("/\\")
-                # Now src_audio_path should point to the real file in tts_generated_dialogs
                 if src_audio_path.is_file():
                     segment_filename = f"{output_base_name}_seg{segment_idx}_reused.wav"
                     dest_path = (self.temp_audio_dir / output_base_name / segment_filename)
@ -123,22 +173,18 @@ class DialogProcessorService:
                             processing_log.append(f"[REUSE] Destination audio file was not created: {dest_path}")
                         else:
                             processing_log.append(f"[REUSE] Destination audio file created: {dest_path}, size={dest_path.stat().st_size} bytes")
-                        # Only include 'type' and 'path' so the concatenator always includes this segment
-                        segment_results.append({
-                            "type": item_type,
-                            "path": str(dest_path)
-                        })
+                        results_map[segment_idx] = {"type": item_type, "path": str(dest_path)}
                         processing_log.append(f"Reused existing audio for item {i+1}: copied from {src_audio_path} to {dest_path}")
                     except Exception as e:
                         error_message = f"Failed to copy reused audio for item {i+1}: {e}"
                         processing_log.append(error_message)
-                        segment_results.append({"type": "error", "message": error_message})
+                        results_map[segment_idx] = {"type": "error", "message": error_message}
                     segment_idx += 1
                     continue
                 else:
                     error_message = f"Audio file for reuse not found at {src_audio_path} for item {i+1}."
                     processing_log.append(error_message)
-                    segment_results.append({"type": "error", "message": error_message})
+                    results_map[segment_idx] = {"type": "error", "message": error_message}
                     segment_idx += 1
                     continue
 
@ -147,70 +193,81 @@ class DialogProcessorService:
                 text = item.get("text")
                 if not speaker_id or not text:
                     processing_log.append(f"Skipping speech item {i+1} due to missing speaker_id or text.")
-                    segment_results.append({"type": "error", "message": "Missing speaker_id or text"})
+                    results_map[segment_idx] = {"type": "error", "message": "Missing speaker_id or text"}
+                    segment_idx += 1
                     continue
 
-                # Validate speaker_id and get speaker_sample_path
                 speaker_info = self.speaker_service.get_speaker_by_id(speaker_id)
                 if not speaker_info:
                     processing_log.append(f"Speaker ID '{speaker_id}' not found. Skipping item {i+1}.")
-                    segment_results.append({"type": "error", "message": f"Speaker ID '{speaker_id}' not found"})
+                    results_map[segment_idx] = {"type": "error", "message": f"Speaker ID '{speaker_id}' not found"}
+                    segment_idx += 1
                     continue
                 if not speaker_info.sample_path:
                     processing_log.append(f"Speaker ID '{speaker_id}' has no sample path defined. Skipping item {i+1}.")
-                    segment_results.append({"type": "error", "message": f"Speaker ID '{speaker_id}' has no sample path defined"})
+                    results_map[segment_idx] = {"type": "error", "message": f"Speaker ID '{speaker_id}' has no sample path defined"}
+                    segment_idx += 1
                     continue
 
-                # speaker_info.sample_path is relative to config.SPEAKER_DATA_BASE_DIR
                 abs_speaker_sample_path = config.SPEAKER_DATA_BASE_DIR / speaker_info.sample_path
                 if not abs_speaker_sample_path.is_file():
                     processing_log.append(f"Speaker sample file not found or is not a file at '{abs_speaker_sample_path}' for speaker ID '{speaker_id}'. Skipping item {i+1}.")
-                    segment_results.append({"type": "error", "message": f"Speaker sample not a file or not found: {abs_speaker_sample_path}"})
+                    results_map[segment_idx] = {"type": "error", "message": f"Speaker sample not a file or not found: {abs_speaker_sample_path}"}
+                    segment_idx += 1
                     continue
 
                 text_chunks = self._split_text(text)
                 processing_log.append(f"Split text for speaker '{speaker_id}' into {len(text_chunks)} chunk(s).")
 
                 for chunk_idx, text_chunk in enumerate(text_chunks):
-                    segment_filename_base = f"{output_base_name}_seg{segment_idx}_spk{speaker_id}_chunk{chunk_idx}"
-                    processing_log.append(f"Generating speech for chunk: '{text_chunk[:50]}...' using speaker '{speaker_id}'")
-                    try:
-                        segment_output_path = await self.tts_service.generate_speech(
-                            text=text_chunk,
-                            speaker_id=speaker_id, # For metadata, actual sample path is used by TTS
-                            speaker_sample_path=str(abs_speaker_sample_path),
-                            output_filename_base=segment_filename_base,
-                            output_dir=dialog_temp_dir, # Save to the dialog's temp dir
-                            exaggeration=item.get('exaggeration', 0.5), # Default from Gradio, Pydantic model should provide this
-                            cfg_weight=item.get('cfg_weight', 0.5), # Default from Gradio, Pydantic model should provide this
-                            temperature=item.get('temperature', 0.8) # Default from Gradio, Pydantic model should provide this
-                        )
-                        segment_results.append({
-                            "type": "speech",
-                            "path": str(segment_output_path),
-                            "speaker_id": speaker_id,
-                            "text_chunk": text_chunk
-                        })
-                        processing_log.append(f"Successfully generated segment: {segment_output_path}")
-                    except Exception as e:
-                        error_message = f"Error generating speech for chunk '{text_chunk[:50]}...': {repr(e)}"
-                        processing_log.append(error_message)
-                        segment_results.append({"type": "error", "message": error_message, "text_chunk": text_chunk})
+                    filename_base = f"{output_base_name}_seg{segment_idx}_spk{speaker_id}_chunk{chunk_idx}"
+                    processing_log.append(f"Queueing TTS for chunk: '{text_chunk[:50]}...' using speaker '{speaker_id}'")
+                    planned = {
+                        "segment_idx": segment_idx,
+                        "speaker_id": speaker_id,
+                        "text_chunk": text_chunk,
+                        "abs_speaker_sample_path": abs_speaker_sample_path,
+                        "filename_base": filename_base,
+                        "params": {
+                            'exaggeration': item.get('exaggeration', 0.5),
+                            'cfg_weight': item.get('cfg_weight', 0.5),
+                            'temperature': item.get('temperature', 0.8),
+                        },
+                    }
+                    tasks.append(asyncio.create_task(run_one(planned)))
                     segment_idx += 1
 
             elif item_type == "silence":
                 duration = item.get("duration")
                 if duration is None or duration < 0:
                     processing_log.append(f"Skipping silence item {i+1} due to invalid duration.")
-                    segment_results.append({"type": "error", "message": "Invalid duration for silence"})
+                    results_map[segment_idx] = {"type": "error", "message": "Invalid duration for silence"}
+                    segment_idx += 1
                     continue
-                segment_results.append({"type": "silence", "duration": float(duration)})
+                results_map[segment_idx] = {"type": "silence", "duration": float(duration)}
                 processing_log.append(f"Added silence of {duration}s.")
+                segment_idx += 1
 
             else:
                 processing_log.append(f"Unknown item type '{item_type}' at item {i+1}. Skipping.")
-                segment_results.append({"type": "error", "message": f"Unknown item type: {item_type}"})
+                results_map[segment_idx] = {"type": "error", "message": f"Unknown item type: {item_type}"}
+                segment_idx += 1
 
+        # Await all TTS tasks and merge results
+        if tasks:
+            processing_log.append(
+                f"Dispatching {len(tasks)} TTS task(s) with concurrency limit "
+                f"{getattr(config, 'TTS_MAX_CONCURRENCY', 2)}"
+            )
+            completed = await asyncio.gather(*tasks, return_exceptions=False)
+            for idx, payload, maybe_log in completed:
+                results_map[idx] = payload
+                if maybe_log:
+                    processing_log.append(maybe_log)
+
+        # Build ordered list
+        for idx in sorted(results_map.keys()):
+            segment_results.append(results_map[idx])
+
         # Log the full segment_results list for debugging
         processing_log.append("[DEBUG] Final segment_results list:")
@ -220,7 +277,7 @@ class DialogProcessorService:
         return {
             "log": "\n".join(processing_log),
             "segment_files": segment_results,
-            "temp_dir": str(dialog_temp_dir) # For cleanup or zipping later
+            "temp_dir": str(dialog_temp_dir)
         }
 
 if __name__ == "__main__":

backend/app/services/tts_service.py
@ -1,11 +1,14 @@
 import torch
 import torchaudio
+import asyncio
 from typing import Optional
 from chatterbox.tts import ChatterboxTTS
 from pathlib import Path
 import gc  # Garbage collector for memory management
 import os
 from contextlib import contextmanager
+from datetime import datetime
+import time
 
 # Import configuration
 try:
@ -114,7 +117,11 @@ class TTSService:
         # output_filename_base from DialogProcessorService is expected to be comprehensive (e.g., includes speaker_id, segment info)
         output_file_path = target_output_dir / f"{output_filename_base}.wav"
 
-        print(f"Generating audio for text: \"{text[:50]}...\" with speaker sample: {speaker_sample_path}")
+        start_ts = datetime.now()
+        print(f"[{start_ts.isoformat(timespec='seconds')}] [TTS] START generate+save base={output_filename_base} len={len(text)} sample={speaker_sample_path}")
+        try:
+            def _gen_and_save() -> Path:
+                t0 = time.perf_counter()
                 wav = None
                 try:
                     with torch.no_grad():  # Important for inference
@ -126,18 +133,15 @@ class TTSService:
                         temperature=temperature,
                     )
 
+                    # Save the audio synchronously in the same thread
                     torchaudio.save(str(output_file_path), wav, self.model.sr)
-                    print(f"Audio saved to: {output_file_path}")
+                    t1 = time.perf_counter()
+                    print(f"[TTS-THREAD] Saved {output_file_path.name} in {t1 - t0:.2f}s")
                     return output_file_path
-                except Exception as e:
-                    print(f"Error during TTS generation or saving: {e}")
-                    raise
                 finally:
-                    # Explicitly delete the wav tensor to free memory
+                    # Cleanup in the same thread that created the tensor
                     if wav is not None:
                         del wav
 
-                    # Force garbage collection and cache cleanup
                     gc.collect()
                     if self.device == "cuda":
                         torch.cuda.empty_cache()
@ -145,11 +149,20 @@ class TTSService:
                         if hasattr(torch.mps, "empty_cache"):
                             torch.mps.empty_cache()
 
-        # Unload the model if requested
+            out_path = await asyncio.to_thread(_gen_and_save)
+            end_ts = datetime.now()
+            print(f"[{end_ts.isoformat(timespec='seconds')}] [TTS] END generate+save base={output_filename_base} dur={(end_ts - start_ts).total_seconds():.2f}s -> {out_path}")
+
+            # Optionally unload model after generation
             if unload_after:
                 print("Unloading TTS model after generation...")
                 self.unload_model()
 
+            return out_path
+        except Exception as e:
+            print(f"Error during TTS generation or saving: {e}")
+            raise
 
 # Example usage (for testing, not part of the service itself)
 if __name__ == "__main__":
     async def main_test():

@ -11,6 +11,29 @@
         <div class="container">
             <h1>Chatterbox TTS</h1>
         </div>
+
+        <!-- Paste Script Modal -->
+        <div id="paste-script-modal" class="modal" style="display: none;">
+            <div class="modal-content">
+                <div class="modal-header">
+                    <h3>Paste Dialog Script</h3>
+                    <button class="modal-close" id="paste-script-close">×</button>
+                </div>
+                <div class="modal-body">
+                    <p>Paste JSONL content (one JSON object per line). Example lines:</p>
+                    <pre style="white-space:pre-wrap; background:#f6f8fa; padding:8px; border-radius:4px;">
+{"type":"speech","speaker_id":"alice","text":"Hello there!"}
+{"type":"silence","duration":0.5}
+{"type":"speech","speaker_id":"bob","text":"Hi!"}
+                    </pre>
+                    <textarea id="paste-script-text" rows="10" style="width:100%;" placeholder='Paste JSONL here'></textarea>
+                </div>
+                <div class="modal-footer">
+                    <button id="paste-script-load" class="btn-primary">Load</button>
+                    <button id="paste-script-cancel" class="btn-secondary">Cancel</button>
+                </div>
+            </div>
+        </div>
     </header>
 
     <!-- Global inline notification area -->
@ -55,6 +78,7 @@
                 <button id="save-script-btn">Save Script</button>
                 <input type="file" id="load-script-input" accept=".jsonl" style="display: none;">
                 <button id="load-script-btn">Load Script</button>
+                <button id="paste-script-btn">Paste Script</button>
             </div>
         </section>
     </div>

@ -201,6 +201,12 @@ async function initializeDialogEditor() {
     const saveScriptBtn = document.getElementById('save-script-btn');
     const loadScriptBtn = document.getElementById('load-script-btn');
     const loadScriptInput = document.getElementById('load-script-input');
+    const pasteScriptBtn = document.getElementById('paste-script-btn');
+    const pasteModal = document.getElementById('paste-script-modal');
+    const pasteText = document.getElementById('paste-script-text');
+    const pasteLoadBtn = document.getElementById('paste-script-load');
+    const pasteCancelBtn = document.getElementById('paste-script-cancel');
+    const pasteCloseBtn = document.getElementById('paste-script-close');
 
     // Results Display Elements
     const generationLogPre = document.getElementById('generation-log-content'); // Corrected ID
@ -891,6 +897,71 @@ async function initializeDialogEditor() {
         reader.readAsText(file);
     }
 
+    // Load dialog script from pasted JSONL text
+    async function loadDialogScriptFromText(text) {
+        if (!text || !text.trim()) {
+            showNotice('Please paste JSONL content to load.', 'warning', { timeout: 4000 });
+            return false;
+        }
+        try {
+            const lines = text.trim().split('\n');
+            const loadedItems = [];
+
+            for (let i = 0; i < lines.length; i++) {
+                const line = lines[i].trim();
+                if (!line) continue; // Skip empty lines
+                try {
+                    const item = JSON.parse(line);
+                    const validatedItem = validateDialogItem(item, i + 1);
+                    if (validatedItem) {
+                        loadedItems.push(normalizeDialogItem(validatedItem));
+                    }
+                } catch (parseError) {
+                    console.error(`Error parsing line ${i + 1}:`, parseError);
+                    showNotice(`Error parsing line ${i + 1}: ${parseError.message}`, 'error');
+                    return false;
+                }
+            }
+
+            if (loadedItems.length === 0) {
+                showNotice('No valid dialog items found in the pasted content.', 'warning', { timeout: 4000 });
+                return false;
+            }
+
+            // Confirm replacement if existing items
+            if (dialogItems.length > 0) {
+                const confirmed = await confirmAction(
+                    `This will replace your current dialog (${dialogItems.length} items) with the pasted script (${loadedItems.length} items). Continue?`
+                );
+                if (!confirmed) return false;
+            }
+
+            // Ensure speakers are loaded before rendering
+            if (availableSpeakersCache.length === 0) {
+                try {
+                    availableSpeakersCache = await getSpeakers();
+                } catch (error) {
+                    console.error('Error fetching speakers:', error);
+                    showNotice('Could not load speakers. Dialog loaded but speaker names may not display correctly.', 'warning', { timeout: 5000 });
+                }
+            }
+
+            // Replace current dialog
+            dialogItems.splice(0, dialogItems.length, ...loadedItems);
+            // Persist loaded script
+            saveDialogToLocalStorage();
+            renderDialogItems();
+
+            console.log(`Loaded ${loadedItems.length} dialog items from pasted text`);
+            showNotice(`Successfully loaded ${loadedItems.length} dialog items.`, 'success', { timeout: 3000 });
+            return true;
+        } catch (error) {
+            console.error('Error loading dialog script from text:', error);
+            showNotice(`Error loading dialog script: ${error.message}`, 'error');
+            return false;
+        }
+    }
+
     function validateDialogItem(item, lineNumber) {
         if (!item || typeof item !== 'object') {
             throw new Error(`Line ${lineNumber}: Invalid item format`);
@ -946,14 +1017,41 @@ async function initializeDialogEditor() {
             const file = e.target.files[0];
             if (file) {
                 loadDialogScript(file);
-                // Reset input so same file can be loaded again
-                e.target.value = '';
             }
         });
     }
 
+    // --- Paste Script (JSONL) Modal Handlers ---
+    if (pasteScriptBtn && pasteModal && pasteText && pasteLoadBtn && pasteCancelBtn && pasteCloseBtn) {
+        let escHandler = null;
+        const closePasteModal = () => {
+            pasteModal.style.display = 'none';
+            pasteLoadBtn.onclick = null;
+            pasteCancelBtn.onclick = null;
+            pasteCloseBtn.onclick = null;
+            pasteModal.onclick = null;
+            if (escHandler) {
+                document.removeEventListener('keydown', escHandler);
+                escHandler = null;
+            }
+        };
+        const openPasteModal = () => {
+            pasteText.value = '';
+            pasteModal.style.display = 'flex';
+            escHandler = (e) => { if (e.key === 'Escape') closePasteModal(); };
+            document.addEventListener('keydown', escHandler);
+            pasteModal.onclick = (e) => { if (e.target === pasteModal) closePasteModal(); };
+            pasteCloseBtn.onclick = closePasteModal;
+            pasteCancelBtn.onclick = closePasteModal;
+            pasteLoadBtn.onclick = async () => {
+                const ok = await loadDialogScriptFromText(pasteText.value);
+                if (ok) closePasteModal();
+            };
+        };
+        pasteScriptBtn.addEventListener('click', openPasteModal);
+    }
+
     // --- Clear Dialog Button ---
-    // Prefer an existing button with id if present; otherwise, create and insert beside Save/Load.
     let clearDialogBtn = document.getElementById('clear-dialog-btn');
     if (!clearDialogBtn) {
         clearDialogBtn = document.createElement('button');
@ -970,6 +1068,7 @@ async function initializeDialogEditor() {
         }
     }
 
+    if (clearDialogBtn) {
         clearDialogBtn.addEventListener('click', async () => {
             if (dialogItems.length === 0) {
                 showNotice('Dialog is already empty.', 'info', { timeout: 2500 });
@ -985,6 +1084,7 @@ async function initializeDialogEditor() {
             renderDialogItems();
             showNotice('Dialog cleared.', 'success', { timeout: 2500 });
         });
+    }
 
     console.log('Dialog Editor Initialized');
     renderDialogItems(); // Initial render (empty)