added back end concurrency and front end paste feature.
This commit is contained in:
parent
b28a9bcf58
commit
75a2a37252
|
@ -0,0 +1,188 @@
|
|||
# Chatterbox TTS Backend: Bounded Concurrency + File I/O Offload Plan
|
||||
|
||||
Date: 2025-08-14
|
||||
Owner: Backend
|
||||
Status: Proposed (ready to implement)
|
||||
|
||||
## Goals
|
||||
|
||||
- Increase GPU utilization and reduce wall-clock time for dialog generation.
|
||||
- Keep model lifecycle stable (leveraging current `ModelManager`).
|
||||
- Minimal-risk changes: no API shape changes to clients.
|
||||
|
||||
## Scope
|
||||
|
||||
- Implement bounded concurrency for per-line speech chunk generation within a single dialog request.
|
||||
- Offload audio file writes to threads to overlap GPU compute and disk I/O.
|
||||
- Add configuration knobs to tune concurrency.
|
||||
|
||||
## Current State (References)
|
||||
|
||||
- `backend/app/services/dialog_processor_service.py`
|
||||
- `DialogProcessorService.process_dialog()` iterates items and awaits `tts_service.generate_speech(...)` sequentially (lines ~171–201).
|
||||
- `backend/app/services/tts_service.py`
|
||||
- `TTSService.generate_speech()` runs the TTS forward and calls `torchaudio.save(...)` on the event loop thread (blocking).
|
||||
- `backend/app/services/model_manager.py`
|
||||
- `ModelManager.using()` tracks active work; prevents idle eviction during requests.
|
||||
- `backend/app/routers/dialog.py`
|
||||
- `process_dialog_flow()` expects ordered `segment_files` and then concatenates; good to keep order stable.
|
||||
|
||||
## Design Overview
|
||||
|
||||
1) Bounded concurrency at dialog level
|
||||
|
||||
- Plan all output segments with a stable `segment_idx` (including speech chunks, silence, and reused audio).
|
||||
- For speech chunks, schedule concurrent async tasks with a global semaphore set by config `TTS_MAX_CONCURRENCY` (start at 3–4).
|
||||
- Await all tasks and collate results by `segment_idx` to preserve order.
|
||||
|
||||
2) File I/O offload
|
||||
|
||||
- Replace direct `torchaudio.save(...)` with `await asyncio.to_thread(torchaudio.save, ...)` in `TTSService.generate_speech()`.
|
||||
- This lets the next GPU forward start while previous file writes happen on worker threads.
|
||||
|
||||
## Configuration
|
||||
|
||||
Add to `backend/app/config.py`:
|
||||
|
||||
- `TTS_MAX_CONCURRENCY: int` (default: `int(os.getenv("TTS_MAX_CONCURRENCY", "3"))`).
|
||||
- Optional (future): `TTS_ENABLE_AMP_ON_CUDA: bool = True` to allow mixed precision on CUDA only.
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### A. Dialog-level concurrency
|
||||
|
||||
- File: `backend/app/services/dialog_processor_service.py`
|
||||
- Function: `DialogProcessorService.process_dialog()`
|
||||
|
||||
1. Planning pass to assign indices
|
||||
|
||||
- Iterate `dialog_items` and build a list `planned_segments` entries:
|
||||
- For silence or reuse: immediately append a final result with assigned `segment_idx` and continue.
|
||||
- For speech: split into `text_chunks`; for each chunk create a planned entry: `{ segment_idx, type: 'speech', speaker_id, text_chunk, abs_speaker_sample_path, tts_params }`.
|
||||
- Increment `segment_idx` for every planned segment (speech chunk or silence/reuse) to preserve final order.
|
||||
|
||||
2. Concurrency setup
|
||||
|
||||
- Create `sem = asyncio.Semaphore(config.TTS_MAX_CONCURRENCY)`.
|
||||
- For each planned speech segment, create a task with an inner wrapper:
|
||||
|
||||
```python
|
||||
async def run_one(planned):
|
||||
async with sem:
|
||||
try:
|
||||
out_path = await self.tts_service.generate_speech(
|
||||
text=planned.text_chunk,
|
||||
speaker_sample_path=planned.abs_speaker_sample_path,
|
||||
output_filename_base=planned.filename_base,
|
||||
output_dir=dialog_temp_dir,
|
||||
exaggeration=planned.exaggeration,
|
||||
cfg_weight=planned.cfg_weight,
|
||||
temperature=planned.temperature,
|
||||
)
|
||||
return planned.segment_idx, {"type": "speech", "path": str(out_path), "speaker_id": planned.speaker_id, "text_chunk": planned.text_chunk}
|
||||
except Exception as e:
|
||||
return planned.segment_idx, {"type": "error", "message": f"Error generating speech: {e}", "text_chunk": planned.text_chunk}
|
||||
```
|
||||
|
||||
- Schedule with `asyncio.create_task(run_one(p))` and collect tasks.
|
||||
|
||||
3. Await and collate
|
||||
|
||||
- `results_map = {}`; for each completed task, set `results_map[idx] = payload`.
|
||||
- Merge: start with all previously final (silence/reuse/error) entries placed by `segment_idx`, then fill speech results by `segment_idx` into a single `segment_results` list sorted ascending by index.
|
||||
- Keep `processing_log` entries for each planned segment (queued, started, finished, errors).
|
||||
|
||||
4. Return value unchanged
|
||||
|
||||
- Return `{"log": ..., "segment_files": segment_results, "temp_dir": str(dialog_temp_dir)}`. This maintains router and concatenator behavior.
|
||||
|
||||
### B. Offload audio writes
|
||||
|
||||
- File: `backend/app/services/tts_service.py`
|
||||
- Function: `TTSService.generate_speech()`
|
||||
|
||||
1. After obtaining `wav` tensor, replace:
|
||||
|
||||
```python
|
||||
# torchaudio.save(str(output_file_path), wav, self.model.sr)
|
||||
```
|
||||
|
||||
with:
|
||||
|
||||
```python
|
||||
await asyncio.to_thread(torchaudio.save, str(output_file_path), wav, self.model.sr)
|
||||
```
|
||||
|
||||
- Keep the rest of cleanup logic (delete `wav`, `gc.collect()`, cache emptying) unchanged.
|
||||
|
||||
2. Optional (CUDA-only AMP)
|
||||
|
||||
- If CUDA is used and `config.TTS_ENABLE_AMP_ON_CUDA` is True, wrap forward with AMP:
|
||||
|
||||
```python
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
wav = self.model.generate(...)
|
||||
```
|
||||
|
||||
- Leave MPS/CPU code path as-is.
|
||||
|
||||
## Error Handling & Ordering
|
||||
|
||||
- Every planned segment owns a unique `segment_idx`.
|
||||
- On failure, insert an error record at that index; downstream concatenation will skip missing/nonexistent paths already.
|
||||
- Preserve exact output order expected by `routers/dialog.py::process_dialog_flow()`.
|
||||
|
||||
## Performance Expectations
|
||||
|
||||
- GPU util should increase from ~50% to 75–90% depending on dialog size and line lengths.
|
||||
- Wall-clock reduction is workload-dependent; target 1.5–2.5x on multi-line dialogs.
|
||||
|
||||
## Metrics & Instrumentation
|
||||
|
||||
- Add timestamped log entries per segment: planned→queued→started→saved.
|
||||
- Log effective concurrency (max in-flight), and cumulative GPU time if available.
|
||||
- Optionally add a simple timing summary at end of `process_dialog()`.
|
||||
|
||||
## Testing Plan
|
||||
|
||||
1. Unit-ish
|
||||
|
||||
- Small dialog (3 speech lines, 1 silence). Ensure ordering is stable and files exist.
|
||||
- Introduce an invalid speaker to verify error propagation doesn’t break the rest.
|
||||
|
||||
2. Integration
|
||||
|
||||
- POST `/api/dialog/generate` with 20–50 mixed-length lines and a couple silences.
|
||||
- Validate: response OK, concatenated file exists, zip contains all generated speech segments, order preserved.
|
||||
- Compare runtime vs. sequential baseline (before/after).
|
||||
|
||||
3. Stress/limits
|
||||
|
||||
- Long lines split into many chunks; verify no OOM with `TTS_MAX_CONCURRENCY`=3.
|
||||
- Try `TTS_MAX_CONCURRENCY`=1 to simulate sequential; compare metrics.
|
||||
|
||||
## Rollout & Config Defaults
|
||||
|
||||
- Default `TTS_MAX_CONCURRENCY=3`.
|
||||
- Expose via environment variable; no client changes needed.
|
||||
- If instability observed, set `TTS_MAX_CONCURRENCY=1` to revert to sequential behavior quickly.
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
- OOM under high concurrency → Mitigate with low default, easy rollback, and chunking already in place.
|
||||
- Disk I/O saturation → Offload to threads; if disk is a bottleneck, decrease concurrency.
|
||||
- Model thread safety → We call `model.generate` concurrently only up to semaphore cap; if underlying library is not thread-safe for forward passes, consider serializing forwards but still overlapping with file I/O; early logs will reveal.
|
||||
|
||||
## Follow-up (Out of Scope for this change)
|
||||
|
||||
- Dynamic batching queue inside `TTSService` for further GPU efficiency.
|
||||
- CUDA AMP enablement and profiling.
|
||||
- Per-speaker sub-queues if batching requires same-speaker inputs.
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- `TTS_MAX_CONCURRENCY` is configurable; default=3.
|
||||
- File writes occur via `asyncio.to_thread`.
|
||||
- Order of `segment_files` unchanged relative to sequential output.
|
||||
- End-to-end works for both small and large dialogs; error cases logged.
|
||||
- Observed GPU utilization and runtime improve on representative dialog.
|
|
@ -66,6 +66,10 @@ if CORS_ORIGINS != ["*"] and _frontend_host and _frontend_port:
|
|||
|
||||
# Device configuration
|
||||
DEVICE = os.getenv("DEVICE", "auto")
|
||||
|
||||
# Concurrency configuration
|
||||
# Max number of concurrent TTS generation tasks per dialog request
|
||||
TTS_MAX_CONCURRENCY = int(os.getenv("TTS_MAX_CONCURRENCY", "3"))
|
||||
|
||||
# Model idle eviction configuration
|
||||
# Enable/disable idle-based model eviction
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Union
|
||||
import re
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from .tts_service import TTSService
|
||||
from .speaker_service import SpeakerManagementService
|
||||
|
@ -92,24 +94,72 @@ class DialogProcessorService:
|
|||
|
||||
import shutil
|
||||
segment_idx = 0
|
||||
tasks = []
|
||||
results_map: Dict[int, Dict[str, Any]] = {}
|
||||
sem = asyncio.Semaphore(getattr(config, "TTS_MAX_CONCURRENCY", 2))
|
||||
|
||||
async def run_one(planned: Dict[str, Any]):
|
||||
async with sem:
|
||||
text_chunk = planned["text_chunk"]
|
||||
speaker_id = planned["speaker_id"]
|
||||
abs_speaker_sample_path = planned["abs_speaker_sample_path"]
|
||||
filename_base = planned["filename_base"]
|
||||
params = planned["params"]
|
||||
seg_idx = planned["segment_idx"]
|
||||
start_ts = datetime.now()
|
||||
start_line = (
|
||||
f"[{start_ts.isoformat(timespec='seconds')}] [TTS-TASK] START seg_idx={seg_idx} "
|
||||
f"speaker={speaker_id} chunk_len={len(text_chunk)} base={filename_base}"
|
||||
)
|
||||
try:
|
||||
out_path = await self.tts_service.generate_speech(
|
||||
text=text_chunk,
|
||||
speaker_id=speaker_id,
|
||||
speaker_sample_path=str(abs_speaker_sample_path),
|
||||
output_filename_base=filename_base,
|
||||
output_dir=dialog_temp_dir,
|
||||
exaggeration=params.get('exaggeration', 0.5),
|
||||
cfg_weight=params.get('cfg_weight', 0.5),
|
||||
temperature=params.get('temperature', 0.8),
|
||||
)
|
||||
end_ts = datetime.now()
|
||||
duration = (end_ts - start_ts).total_seconds()
|
||||
end_line = (
|
||||
f"[{end_ts.isoformat(timespec='seconds')}] [TTS-TASK] END seg_idx={seg_idx} "
|
||||
f"dur={duration:.2f}s -> {out_path}"
|
||||
)
|
||||
return seg_idx, {
|
||||
"type": "speech",
|
||||
"path": str(out_path),
|
||||
"speaker_id": speaker_id,
|
||||
"text_chunk": text_chunk,
|
||||
}, start_line + "\n" + f"Successfully generated segment: {out_path}" + "\n" + end_line
|
||||
except Exception as e:
|
||||
end_ts = datetime.now()
|
||||
err_line = (
|
||||
f"[{end_ts.isoformat(timespec='seconds')}] [TTS-TASK] ERROR seg_idx={seg_idx} "
|
||||
f"speaker={speaker_id} err={repr(e)}"
|
||||
)
|
||||
return seg_idx, {
|
||||
"type": "error",
|
||||
"message": f"Error generating speech for chunk '{text_chunk[:50]}...': {repr(e)}",
|
||||
"text_chunk": text_chunk,
|
||||
}, err_line
|
||||
|
||||
for i, item in enumerate(dialog_items):
|
||||
item_type = item.get("type")
|
||||
processing_log.append(f"Processing item {i+1}: type='{item_type}'")
|
||||
|
||||
# --- Universal: Handle reuse of existing audio for both speech and silence ---
|
||||
# --- Handle reuse of existing audio ---
|
||||
use_existing_audio = item.get("use_existing_audio", False)
|
||||
audio_url = item.get("audio_url")
|
||||
if use_existing_audio and audio_url:
|
||||
# Determine source path (handle both absolute and relative)
|
||||
# Map web URL to actual file location in tts_generated_dialogs
|
||||
if audio_url.startswith("/generated_audio/"):
|
||||
src_audio_path = config.DIALOG_OUTPUT_DIR / audio_url[len("/generated_audio/"):]
|
||||
else:
|
||||
src_audio_path = Path(audio_url)
|
||||
if not src_audio_path.is_absolute():
|
||||
# Assume relative to the generated audio root dir
|
||||
src_audio_path = config.DIALOG_OUTPUT_DIR / audio_url.lstrip("/\\")
|
||||
# Now src_audio_path should point to the real file in tts_generated_dialogs
|
||||
if src_audio_path.is_file():
|
||||
segment_filename = f"{output_base_name}_seg{segment_idx}_reused.wav"
|
||||
dest_path = (self.temp_audio_dir / output_base_name / segment_filename)
|
||||
|
@ -123,22 +173,18 @@ class DialogProcessorService:
|
|||
processing_log.append(f"[REUSE] Destination audio file was not created: {dest_path}")
|
||||
else:
|
||||
processing_log.append(f"[REUSE] Destination audio file created: {dest_path}, size={dest_path.stat().st_size} bytes")
|
||||
# Only include 'type' and 'path' so the concatenator always includes this segment
|
||||
segment_results.append({
|
||||
"type": item_type,
|
||||
"path": str(dest_path)
|
||||
})
|
||||
results_map[segment_idx] = {"type": item_type, "path": str(dest_path)}
|
||||
processing_log.append(f"Reused existing audio for item {i+1}: copied from {src_audio_path} to {dest_path}")
|
||||
except Exception as e:
|
||||
error_message = f"Failed to copy reused audio for item {i+1}: {e}"
|
||||
processing_log.append(error_message)
|
||||
segment_results.append({"type": "error", "message": error_message})
|
||||
results_map[segment_idx] = {"type": "error", "message": error_message}
|
||||
segment_idx += 1
|
||||
continue
|
||||
else:
|
||||
error_message = f"Audio file for reuse not found at {src_audio_path} for item {i+1}."
|
||||
processing_log.append(error_message)
|
||||
segment_results.append({"type": "error", "message": error_message})
|
||||
results_map[segment_idx] = {"type": "error", "message": error_message}
|
||||
segment_idx += 1
|
||||
continue
|
||||
|
||||
|
@ -147,70 +193,81 @@ class DialogProcessorService:
|
|||
text = item.get("text")
|
||||
if not speaker_id or not text:
|
||||
processing_log.append(f"Skipping speech item {i+1} due to missing speaker_id or text.")
|
||||
segment_results.append({"type": "error", "message": "Missing speaker_id or text"})
|
||||
results_map[segment_idx] = {"type": "error", "message": "Missing speaker_id or text"}
|
||||
segment_idx += 1
|
||||
continue
|
||||
|
||||
# Validate speaker_id and get speaker_sample_path
|
||||
speaker_info = self.speaker_service.get_speaker_by_id(speaker_id)
|
||||
if not speaker_info:
|
||||
processing_log.append(f"Speaker ID '{speaker_id}' not found. Skipping item {i+1}.")
|
||||
segment_results.append({"type": "error", "message": f"Speaker ID '{speaker_id}' not found"})
|
||||
results_map[segment_idx] = {"type": "error", "message": f"Speaker ID '{speaker_id}' not found"}
|
||||
segment_idx += 1
|
||||
continue
|
||||
if not speaker_info.sample_path:
|
||||
processing_log.append(f"Speaker ID '{speaker_id}' has no sample path defined. Skipping item {i+1}.")
|
||||
segment_results.append({"type": "error", "message": f"Speaker ID '{speaker_id}' has no sample path defined"})
|
||||
results_map[segment_idx] = {"type": "error", "message": f"Speaker ID '{speaker_id}' has no sample path defined"}
|
||||
segment_idx += 1
|
||||
continue
|
||||
|
||||
# speaker_info.sample_path is relative to config.SPEAKER_DATA_BASE_DIR
|
||||
abs_speaker_sample_path = config.SPEAKER_DATA_BASE_DIR / speaker_info.sample_path
|
||||
if not abs_speaker_sample_path.is_file():
|
||||
processing_log.append(f"Speaker sample file not found or is not a file at '{abs_speaker_sample_path}' for speaker ID '{speaker_id}'. Skipping item {i+1}.")
|
||||
segment_results.append({"type": "error", "message": f"Speaker sample not a file or not found: {abs_speaker_sample_path}"})
|
||||
results_map[segment_idx] = {"type": "error", "message": f"Speaker sample not a file or not found: {abs_speaker_sample_path}"}
|
||||
segment_idx += 1
|
||||
continue
|
||||
|
||||
text_chunks = self._split_text(text)
|
||||
processing_log.append(f"Split text for speaker '{speaker_id}' into {len(text_chunks)} chunk(s).")
|
||||
|
||||
for chunk_idx, text_chunk in enumerate(text_chunks):
|
||||
segment_filename_base = f"{output_base_name}_seg{segment_idx}_spk{speaker_id}_chunk{chunk_idx}"
|
||||
processing_log.append(f"Generating speech for chunk: '{text_chunk[:50]}...' using speaker '{speaker_id}'")
|
||||
|
||||
try:
|
||||
segment_output_path = await self.tts_service.generate_speech(
|
||||
text=text_chunk,
|
||||
speaker_id=speaker_id, # For metadata, actual sample path is used by TTS
|
||||
speaker_sample_path=str(abs_speaker_sample_path),
|
||||
output_filename_base=segment_filename_base,
|
||||
output_dir=dialog_temp_dir, # Save to the dialog's temp dir
|
||||
exaggeration=item.get('exaggeration', 0.5), # Default from Gradio, Pydantic model should provide this
|
||||
cfg_weight=item.get('cfg_weight', 0.5), # Default from Gradio, Pydantic model should provide this
|
||||
temperature=item.get('temperature', 0.8) # Default from Gradio, Pydantic model should provide this
|
||||
)
|
||||
segment_results.append({
|
||||
"type": "speech",
|
||||
"path": str(segment_output_path),
|
||||
"speaker_id": speaker_id,
|
||||
"text_chunk": text_chunk
|
||||
})
|
||||
processing_log.append(f"Successfully generated segment: {segment_output_path}")
|
||||
except Exception as e:
|
||||
error_message = f"Error generating speech for chunk '{text_chunk[:50]}...': {repr(e)}"
|
||||
processing_log.append(error_message)
|
||||
segment_results.append({"type": "error", "message": error_message, "text_chunk": text_chunk})
|
||||
filename_base = f"{output_base_name}_seg{segment_idx}_spk{speaker_id}_chunk{chunk_idx}"
|
||||
processing_log.append(f"Queueing TTS for chunk: '{text_chunk[:50]}...' using speaker '{speaker_id}'")
|
||||
planned = {
|
||||
"segment_idx": segment_idx,
|
||||
"speaker_id": speaker_id,
|
||||
"text_chunk": text_chunk,
|
||||
"abs_speaker_sample_path": abs_speaker_sample_path,
|
||||
"filename_base": filename_base,
|
||||
"params": {
|
||||
'exaggeration': item.get('exaggeration', 0.5),
|
||||
'cfg_weight': item.get('cfg_weight', 0.5),
|
||||
'temperature': item.get('temperature', 0.8),
|
||||
},
|
||||
}
|
||||
tasks.append(asyncio.create_task(run_one(planned)))
|
||||
segment_idx += 1
|
||||
|
||||
|
||||
elif item_type == "silence":
|
||||
duration = item.get("duration")
|
||||
if duration is None or duration < 0:
|
||||
processing_log.append(f"Skipping silence item {i+1} due to invalid duration.")
|
||||
segment_results.append({"type": "error", "message": "Invalid duration for silence"})
|
||||
results_map[segment_idx] = {"type": "error", "message": "Invalid duration for silence"}
|
||||
segment_idx += 1
|
||||
continue
|
||||
segment_results.append({"type": "silence", "duration": float(duration)})
|
||||
results_map[segment_idx] = {"type": "silence", "duration": float(duration)}
|
||||
processing_log.append(f"Added silence of {duration}s.")
|
||||
|
||||
segment_idx += 1
|
||||
|
||||
else:
|
||||
processing_log.append(f"Unknown item type '{item_type}' at item {i+1}. Skipping.")
|
||||
segment_results.append({"type": "error", "message": f"Unknown item type: {item_type}"})
|
||||
results_map[segment_idx] = {"type": "error", "message": f"Unknown item type: {item_type}"}
|
||||
segment_idx += 1
|
||||
|
||||
# Await all TTS tasks and merge results
|
||||
if tasks:
|
||||
processing_log.append(
|
||||
f"Dispatching {len(tasks)} TTS task(s) with concurrency limit "
|
||||
f"{getattr(config, 'TTS_MAX_CONCURRENCY', 2)}"
|
||||
)
|
||||
completed = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
for idx, payload, maybe_log in completed:
|
||||
results_map[idx] = payload
|
||||
if maybe_log:
|
||||
processing_log.append(maybe_log)
|
||||
|
||||
# Build ordered list
|
||||
for idx in sorted(results_map.keys()):
|
||||
segment_results.append(results_map[idx])
|
||||
|
||||
# Log the full segment_results list for debugging
|
||||
processing_log.append("[DEBUG] Final segment_results list:")
|
||||
|
@ -220,7 +277,7 @@ class DialogProcessorService:
|
|||
return {
|
||||
"log": "\n".join(processing_log),
|
||||
"segment_files": segment_results,
|
||||
"temp_dir": str(dialog_temp_dir) # For cleanup or zipping later
|
||||
"temp_dir": str(dialog_temp_dir)
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
import torch
|
||||
import torchaudio
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
from pathlib import Path
|
||||
import gc # Garbage collector for memory management
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
# Import configuration
|
||||
try:
|
||||
|
@ -114,42 +117,52 @@ class TTSService:
|
|||
# output_filename_base from DialogProcessorService is expected to be comprehensive (e.g., includes speaker_id, segment info)
|
||||
output_file_path = target_output_dir / f"{output_filename_base}.wav"
|
||||
|
||||
print(f"Generating audio for text: \"{text[:50]}...\" with speaker sample: {speaker_sample_path}")
|
||||
wav = None
|
||||
start_ts = datetime.now()
|
||||
print(f"[{start_ts.isoformat(timespec='seconds')}] [TTS] START generate+save base={output_filename_base} len={len(text)} sample={speaker_sample_path}")
|
||||
try:
|
||||
with torch.no_grad(): # Important for inference
|
||||
wav = self.model.generate(
|
||||
text=text,
|
||||
audio_prompt_path=str(speaker_sample_p), # Must be a string path
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
torchaudio.save(str(output_file_path), wav, self.model.sr)
|
||||
print(f"Audio saved to: {output_file_path}")
|
||||
return output_file_path
|
||||
except Exception as e:
|
||||
print(f"Error during TTS generation or saving: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Explicitly delete the wav tensor to free memory
|
||||
if wav is not None:
|
||||
del wav
|
||||
|
||||
# Force garbage collection and cache cleanup
|
||||
gc.collect()
|
||||
if self.device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
elif self.device == "mps":
|
||||
if hasattr(torch.mps, "empty_cache"):
|
||||
torch.mps.empty_cache()
|
||||
|
||||
# Unload the model if requested
|
||||
def _gen_and_save() -> Path:
|
||||
t0 = time.perf_counter()
|
||||
wav = None
|
||||
try:
|
||||
with torch.no_grad(): # Important for inference
|
||||
wav = self.model.generate(
|
||||
text=text,
|
||||
audio_prompt_path=str(speaker_sample_p), # Must be a string path
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Save the audio synchronously in the same thread
|
||||
torchaudio.save(str(output_file_path), wav, self.model.sr)
|
||||
t1 = time.perf_counter()
|
||||
print(f"[TTS-THREAD] Saved {output_file_path.name} in {t1 - t0:.2f}s")
|
||||
return output_file_path
|
||||
finally:
|
||||
# Cleanup in the same thread that created the tensor
|
||||
if wav is not None:
|
||||
del wav
|
||||
gc.collect()
|
||||
if self.device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
elif self.device == "mps":
|
||||
if hasattr(torch.mps, "empty_cache"):
|
||||
torch.mps.empty_cache()
|
||||
|
||||
out_path = await asyncio.to_thread(_gen_and_save)
|
||||
end_ts = datetime.now()
|
||||
print(f"[{end_ts.isoformat(timespec='seconds')}] [TTS] END generate+save base={output_filename_base} dur={(end_ts - start_ts).total_seconds():.2f}s -> {out_path}")
|
||||
|
||||
# Optionally unload model after generation
|
||||
if unload_after:
|
||||
print("Unloading TTS model after generation...")
|
||||
self.unload_model()
|
||||
|
||||
return out_path
|
||||
except Exception as e:
|
||||
print(f"Error during TTS generation or saving: {e}")
|
||||
raise
|
||||
|
||||
# Example usage (for testing, not part of the service itself)
|
||||
if __name__ == "__main__":
|
||||
async def main_test():
|
||||
|
|
|
@ -11,6 +11,29 @@
|
|||
<div class="container">
|
||||
<h1>Chatterbox TTS</h1>
|
||||
</div>
|
||||
|
||||
<!-- Paste Script Modal -->
|
||||
<div id="paste-script-modal" class="modal" style="display: none;">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3>Paste Dialog Script</h3>
|
||||
<button class="modal-close" id="paste-script-close">×</button>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<p>Paste JSONL content (one JSON object per line). Example lines:</p>
|
||||
<pre style="white-space:pre-wrap; background:#f6f8fa; padding:8px; border-radius:4px;">
|
||||
{"type":"speech","speaker_id":"alice","text":"Hello there!"}
|
||||
{"type":"silence","duration":0.5}
|
||||
{"type":"speech","speaker_id":"bob","text":"Hi!"}
|
||||
</pre>
|
||||
<textarea id="paste-script-text" rows="10" style="width:100%;" placeholder='Paste JSONL here'></textarea>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<button id="paste-script-load" class="btn-primary">Load</button>
|
||||
<button id="paste-script-cancel" class="btn-secondary">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Global inline notification area -->
|
||||
|
@ -55,6 +78,7 @@
|
|||
<button id="save-script-btn">Save Script</button>
|
||||
<input type="file" id="load-script-input" accept=".jsonl" style="display: none;">
|
||||
<button id="load-script-btn">Load Script</button>
|
||||
<button id="paste-script-btn">Paste Script</button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
|
@ -108,8 +132,8 @@
|
|||
</div>
|
||||
</footer>
|
||||
|
||||
<!-- TTS Settings Modal -->
|
||||
<div id="tts-settings-modal" class="modal" style="display: none;">
|
||||
<!-- TTS Settings Modal -->
|
||||
<div id="tts-settings-modal" class="modal" style="display: none;">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3>TTS Settings</h3>
|
||||
|
|
|
@ -201,6 +201,12 @@ async function initializeDialogEditor() {
|
|||
const saveScriptBtn = document.getElementById('save-script-btn');
|
||||
const loadScriptBtn = document.getElementById('load-script-btn');
|
||||
const loadScriptInput = document.getElementById('load-script-input');
|
||||
const pasteScriptBtn = document.getElementById('paste-script-btn');
|
||||
const pasteModal = document.getElementById('paste-script-modal');
|
||||
const pasteText = document.getElementById('paste-script-text');
|
||||
const pasteLoadBtn = document.getElementById('paste-script-load');
|
||||
const pasteCancelBtn = document.getElementById('paste-script-cancel');
|
||||
const pasteCloseBtn = document.getElementById('paste-script-close');
|
||||
|
||||
// Results Display Elements
|
||||
const generationLogPre = document.getElementById('generation-log-content'); // Corrected ID
|
||||
|
@ -891,6 +897,71 @@ async function initializeDialogEditor() {
|
|||
reader.readAsText(file);
|
||||
}
|
||||
|
||||
// Load dialog script from pasted JSONL text
|
||||
async function loadDialogScriptFromText(text) {
|
||||
if (!text || !text.trim()) {
|
||||
showNotice('Please paste JSONL content to load.', 'warning', { timeout: 4000 });
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
const lines = text.trim().split('\n');
|
||||
const loadedItems = [];
|
||||
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const line = lines[i].trim();
|
||||
if (!line) continue; // Skip empty lines
|
||||
try {
|
||||
const item = JSON.parse(line);
|
||||
const validatedItem = validateDialogItem(item, i + 1);
|
||||
if (validatedItem) {
|
||||
loadedItems.push(normalizeDialogItem(validatedItem));
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error(`Error parsing line ${i + 1}:`, parseError);
|
||||
showNotice(`Error parsing line ${i + 1}: ${parseError.message}`, 'error');
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (loadedItems.length === 0) {
|
||||
showNotice('No valid dialog items found in the pasted content.', 'warning', { timeout: 4000 });
|
||||
return false;
|
||||
}
|
||||
|
||||
// Confirm replacement if existing items
|
||||
if (dialogItems.length > 0) {
|
||||
const confirmed = await confirmAction(
|
||||
`This will replace your current dialog (${dialogItems.length} items) with the pasted script (${loadedItems.length} items). Continue?`
|
||||
);
|
||||
if (!confirmed) return false;
|
||||
}
|
||||
|
||||
// Ensure speakers are loaded before rendering
|
||||
if (availableSpeakersCache.length === 0) {
|
||||
try {
|
||||
availableSpeakersCache = await getSpeakers();
|
||||
} catch (error) {
|
||||
console.error('Error fetching speakers:', error);
|
||||
showNotice('Could not load speakers. Dialog loaded but speaker names may not display correctly.', 'warning', { timeout: 5000 });
|
||||
}
|
||||
}
|
||||
|
||||
// Replace current dialog
|
||||
dialogItems.splice(0, dialogItems.length, ...loadedItems);
|
||||
// Persist loaded script
|
||||
saveDialogToLocalStorage();
|
||||
renderDialogItems();
|
||||
|
||||
console.log(`Loaded ${loadedItems.length} dialog items from pasted text`);
|
||||
showNotice(`Successfully loaded ${loadedItems.length} dialog items.`, 'success', { timeout: 3000 });
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error('Error loading dialog script from text:', error);
|
||||
showNotice(`Error loading dialog script: ${error.message}`, 'error');
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function validateDialogItem(item, lineNumber) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
throw new Error(`Line ${lineNumber}: Invalid item format`);
|
||||
|
@ -946,14 +1017,41 @@ async function initializeDialogEditor() {
|
|||
const file = e.target.files[0];
|
||||
if (file) {
|
||||
loadDialogScript(file);
|
||||
// Reset input so same file can be loaded again
|
||||
e.target.value = '';
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// --- Paste Script (JSONL) Modal Handlers ---
|
||||
if (pasteScriptBtn && pasteModal && pasteText && pasteLoadBtn && pasteCancelBtn && pasteCloseBtn) {
|
||||
let escHandler = null;
|
||||
const closePasteModal = () => {
|
||||
pasteModal.style.display = 'none';
|
||||
pasteLoadBtn.onclick = null;
|
||||
pasteCancelBtn.onclick = null;
|
||||
pasteCloseBtn.onclick = null;
|
||||
pasteModal.onclick = null;
|
||||
if (escHandler) {
|
||||
document.removeEventListener('keydown', escHandler);
|
||||
escHandler = null;
|
||||
}
|
||||
};
|
||||
const openPasteModal = () => {
|
||||
pasteText.value = '';
|
||||
pasteModal.style.display = 'flex';
|
||||
escHandler = (e) => { if (e.key === 'Escape') closePasteModal(); };
|
||||
document.addEventListener('keydown', escHandler);
|
||||
pasteModal.onclick = (e) => { if (e.target === pasteModal) closePasteModal(); };
|
||||
pasteCloseBtn.onclick = closePasteModal;
|
||||
pasteCancelBtn.onclick = closePasteModal;
|
||||
pasteLoadBtn.onclick = async () => {
|
||||
const ok = await loadDialogScriptFromText(pasteText.value);
|
||||
if (ok) closePasteModal();
|
||||
};
|
||||
};
|
||||
pasteScriptBtn.addEventListener('click', openPasteModal);
|
||||
}
|
||||
|
||||
// --- Clear Dialog Button ---
|
||||
// Prefer an existing button with id if present; otherwise, create and insert beside Save/Load.
|
||||
let clearDialogBtn = document.getElementById('clear-dialog-btn');
|
||||
if (!clearDialogBtn) {
|
||||
clearDialogBtn = document.createElement('button');
|
||||
|
@ -970,21 +1068,23 @@ async function initializeDialogEditor() {
|
|||
}
|
||||
}
|
||||
|
||||
clearDialogBtn.addEventListener('click', async () => {
|
||||
if (dialogItems.length === 0) {
|
||||
showNotice('Dialog is already empty.', 'info', { timeout: 2500 });
|
||||
return;
|
||||
}
|
||||
const ok = await confirmAction(`This will remove ${dialogItems.length} dialog item(s). Continue?`);
|
||||
if (!ok) return;
|
||||
// Clear any transient input UI
|
||||
if (typeof clearTempInputArea === 'function') clearTempInputArea();
|
||||
// Clear state and persistence
|
||||
dialogItems.splice(0, dialogItems.length);
|
||||
try { localStorage.removeItem(LS_KEY); } catch (e) { /* ignore */ }
|
||||
renderDialogItems();
|
||||
showNotice('Dialog cleared.', 'success', { timeout: 2500 });
|
||||
});
|
||||
if (clearDialogBtn) {
|
||||
clearDialogBtn.addEventListener('click', async () => {
|
||||
if (dialogItems.length === 0) {
|
||||
showNotice('Dialog is already empty.', 'info', { timeout: 2500 });
|
||||
return;
|
||||
}
|
||||
const ok = await confirmAction(`This will remove ${dialogItems.length} dialog item(s). Continue?`);
|
||||
if (!ok) return;
|
||||
// Clear any transient input UI
|
||||
if (typeof clearTempInputArea === 'function') clearTempInputArea();
|
||||
// Clear state and persistence
|
||||
dialogItems.splice(0, dialogItems.length);
|
||||
try { localStorage.removeItem(LS_KEY); } catch (e) { /* ignore */ }
|
||||
renderDialogItems();
|
||||
showNotice('Dialog cleared.', 'success', { timeout: 2500 });
|
||||
});
|
||||
}
|
||||
|
||||
console.log('Dialog Editor Initialized');
|
||||
renderDialogItems(); // Initial render (empty)
|
||||
|
|
Loading…
Reference in New Issue