Major update: Enhanced memory management, configurable silence gaps, and file organization

- Added aggressive memory cleanup for both single and dialog generation
- Implemented configurable silence gaps between dialog lines
- Organized output files into single_output/ and dialog_output/ directories
- Fixed torch import scoping issues
- Updated README with comprehensive documentation
Steve White 2025-06-04 12:37:52 -05:00
parent 869914e8a0
commit 769daab4c7
3 changed files with 264 additions and 47 deletions

README.md (new file, 73 lines)

@@ -0,0 +1,73 @@
# Chatterbox TTS Gradio App
This Gradio application provides a user interface for text-to-speech generation using the Chatterbox TTS model. It supports both single utterance generation and multi-speaker dialog generation with configurable silence gaps.
## Features
- **Single Utterance Generation**: Generate speech from text using a selected speaker
- **Dialog Generation**: Create multi-speaker conversations with configurable silence gaps
- **Speaker Management**: Add/remove speakers with custom audio samples
- **Memory Optimization**: Automatic model cleanup after generation
- **Output Organization**: Files saved in `single_output/` and `dialog_output/` directories
## Getting Started
1. Clone the repository:
```bash
git clone https://github.com/your-username/chatterbox-test.git
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
3. Prepare speaker samples:
- Create a `speaker_samples/` directory
- Add audio samples (WAV format) for each speaker
- Update `speakers.yaml` with speaker names and file paths, for example:
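A minimal `speakers.yaml` maps one speaker name per line to an audio file (these entries mirror the repository's bundled config):
```yaml
# speaker name -> reference audio sample
Adam: Adam.mp3
Maria: speaker_samples/maria.mp3
```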
4. Run the app:
```bash
python gradio_app.py
```
## Usage
### Single Utterance Tab
- Select a speaker from the dropdown
- Enter text to synthesize
- Adjust generation parameters as needed
- Click "Generate Speech"
### Dialog Generation Tab
1. Add speakers using the speaker configuration section
2. Enter dialog in the format:
```
Speaker1: "Hello, how are you?"
Speaker2: "I'm doing well!"
Silence: 0.5
Speaker1: "What are your plans for today?"
```
3. Set output base name
4. Click "Generate Dialog"
## File Organization
- Generated single utterances are saved to `single_output/`
- Dialog generation files are saved to `dialog_output/`
- Concatenated dialog files carry a `_concatenated.wav` suffix
- All files are zipped together for download; see the example layout below
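For example, a dialog generated with the output base name `chat` would produce something like this (the numbered clips and the `_concatenated.wav` name follow the app's naming scheme; the zip filename here is illustrative):
```
dialog_output/
├── 001-chat.wav           # one numbered clip per dialog line or chunk
├── 002-chat.wav
├── chat_concatenated.wav  # clips joined with the configured silence gaps
└── chat.zip               # bundled files for download (name illustrative)
```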
## Memory Management
The app automatically:
- Cleans up the TTS model after each generation
- Frees GPU memory (for CUDA/MPS devices)
- Deletes intermediate tensors to minimize memory footprint
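A minimal sketch of this cleanup pattern (the `cleanup_model` helper name is illustrative; the `gc` and `torch` calls match what the app runs after each generation):
```python
import gc
import torch

def cleanup_model(model):
    # Move weights off the accelerator before dropping the reference
    if hasattr(model, 'cpu'):
        model.cpu()
    del model
    gc.collect()  # force collection of freed tensors
    # Return cached memory to the device allocator
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        torch.mps.empty_cache()  # Apple Silicon (MPS)
```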
## Troubleshooting
- **"Skipping unknown speaker"**: Add the speaker first using the speaker configuration
- **"Sample file not found"**: Verify the audio file exists in `speaker_samples/`
- **Memory issues**: Try enabling "Re-initialize model each line" for long dialogs

gradio_app.py

@@ -57,13 +57,19 @@ def split_text_at_sentence_boundaries(text, max_length=300):
     return chunks
 
 def parse_dialog_line(line):
-    """Parse a dialog line in the format: Name: "Text"""
-    pattern = r'^([^:]+):\s*"([^"]+)"$'
-    match = re.match(pattern, line.strip())
+    """Parse a dialog line to extract speaker and text."""
+    # Handle silence lines
+    silence_match = re.match(r'^Silence:\s*([0-9]*\.?[0-9]+)$', line.strip(), re.IGNORECASE)
+    if silence_match:
+        duration = float(silence_match.group(1))
+        return "SILENCE", duration
+    # Handle regular dialog lines
+    match = re.match(r'^([^:]+):\s*["\'](.+)["\']$', line.strip())
     if match:
         speaker = match.group(1).strip()
         text = match.group(2).strip()
-        return (speaker, text)
+        return speaker, text
     return None
 
 def save_speakers_config(speakers_dict: Dict) -> str:
@@ -123,34 +129,80 @@ def update_speakers_dropdown():
     return gr.Dropdown.update(choices=list(speakers.keys()))
 
 def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens):
-    # Get sample path from selection or upload
-    sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
-
-    if not os.path.exists(sample_path):
-        raise gr.Error("Sample file not found!")
-
-    # Load model (cached automatically by Gradio)
-    tts = ChatterboxTTS.from_pretrained(device="mps")
-
-    # Generate audio with advanced controls
-    gen_kwargs = dict(
-        text=text,
-        audio_prompt_path=sample_path,
-        exaggeration=exaggeration,
-        cfg_weight=cfg_weight,
-        temperature=temperature
-    )
-    # max_new_tokens is not supported by the current TTS library, so we ignore it here
-    wav = tts.generate(**gen_kwargs)
-
-    # Save with timestamp
-    output_path = f"output_{int(time.time())}.wav"
-    ta.save(output_path, wav, tts.sr)
-    return output_path, output_path
+    import torch  # Ensure torch is available in function scope
+    tts = None
+    wav = None
+    try:
+        # Get sample path from selection or upload
+        sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
+
+        if not os.path.exists(sample_path):
+            raise gr.Error("Sample file not found!")
+
+        # Load model (cached automatically by Gradio)
+        tts = ChatterboxTTS.from_pretrained(device="mps")
+
+        # Generate audio with advanced controls - disable gradients for inference
+        with torch.no_grad():
+            gen_kwargs = dict(
+                text=text,
+                audio_prompt_path=sample_path,
+                exaggeration=exaggeration,
+                cfg_weight=cfg_weight,
+                temperature=temperature
+            )
+            # max_new_tokens is not supported by the current TTS library, so we ignore it here
+            wav = tts.generate(**gen_kwargs)
+
+        # Create output directory for single utterances
+        output_dir = "single_output"
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save with timestamp in the subdirectory
+        output_path = os.path.join(output_dir, f"output_{int(time.time())}.wav")
+        ta.save(output_path, wav, tts.sr)
+        return output_path, output_path
+    finally:
+        # Clean up audio tensor first
+        if wav is not None:
+            try:
+                del wav
+            except:
+                pass
+
+        # Clean up model to free memory
+        if tts is not None:
+            print("Cleaning up TTS model to free memory...")  # Debug log
+            try:
+                # Move model to CPU and delete reference
+                if hasattr(tts, 'cpu'):
+                    tts.cpu()
+                del tts
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                # Clear GPU cache - both CUDA and MPS
+                try:
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    # Clear MPS cache for Apple Silicon
+                    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                        torch.mps.empty_cache()
+                except:
+                    pass
+                print("Model cleanup completed")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during model cleanup: {cleanup_error}")  # Debug log
 
 def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()):
     """Process dialog text and generate audio files."""
+    import torch  # Ensure torch is available in function scope
+    model = None
     try:
         print("Starting dialog processing...")  # Debug log
         print(f"Speaker samples: {speaker_samples}")  # Debug log
@@ -167,7 +219,6 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         print(f"Processing {len(dialog_lines)} dialog lines")  # Debug log
 
         # Initialize model only once if not reinitializing per line
-        model = None
         if not reinit_each_line:
             progress(0.1, desc="Loading TTS model...")
             try:
@@ -183,6 +234,7 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         # Process each dialog line
         file_counter = 1
         output_files = []
+        silence_durations = []  # Track silence durations for concatenation
         summary = []
 
         for i, line in enumerate(dialog_lines):
@@ -195,13 +247,25 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 print(f"Skipping line (invalid format): {line}")  # Debug log
                 continue
 
-            speaker, text = parsed
-            print(f"Found speaker: {speaker}, text: {text[:50]}...")  # Debug log
+            speaker, text_or_duration = parsed
+            print(f"Found speaker: {speaker}, content: {text_or_duration}")  # Debug log
+
+            # Handle silence entries
+            if speaker == "SILENCE":
+                duration = text_or_duration
+                print(f"Adding silence duration: {duration} seconds")  # Debug log
+                silence_durations.append(duration)
+                summary.append(f"Silence: {duration} seconds")
+                continue
 
+            # Handle regular dialog
+            text = text_or_duration
             if speaker not in speaker_samples:
                 msg = f"Skipping unknown speaker: {speaker}"
                 print(msg)  # Debug log
                 summary.append(msg)
+                # Add default silence duration if we skip a speaker
+                silence_durations.append(1.0)
                 continue
 
             sample_path = speaker_samples[speaker]
@@ -209,6 +273,8 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 msg = f"Audio sample not found for speaker '{speaker}': {sample_path}"
                 print(msg)  # Debug log
                 summary.append(msg)
+                # Add default silence duration if we skip due to missing sample
+                silence_durations.append(1.0)
                 continue
 
             if reinit_each_line or model is None:
@@ -222,8 +288,11 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                     output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                     print(f"Generating audio for chunk: {chunk[:50]}...")  # Debug log
                     try:
-                        wav = model.generate(chunk, audio_prompt_path=sample_path)
+                        with torch.no_grad():
+                            wav = model.generate(chunk, audio_prompt_path=sample_path)
                         ta.save(output_file, wav, model.sr)
+                        # Clean up wav tensor immediately
+                        del wav
                         output_files.append(output_file)
                         summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...")
                         file_counter += 1
@@ -237,8 +306,11 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                 print(f"Generating audio: {text[:50]}...")  # Debug log
                 try:
-                    wav = model.generate(text, audio_prompt_path=sample_path)
+                    with torch.no_grad():
+                        wav = model.generate(text, audio_prompt_path=sample_path)
                     ta.save(output_file, wav, model.sr)
+                    # Clean up wav tensor immediately
+                    del wav
                     output_files.append(output_file)
                     summary.append(f"{output_file}: {speaker} - {text[:50]}...")
                     file_counter += 1
@@ -258,13 +330,14 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         if not output_files:
             return "Error: No audio files were generated. Check speaker names and audio samples.", None
 
-        # Concatenate all audio files with 1-second gaps
+        # Concatenate all audio files with configurable gaps
         concatenated_file = None
+        waveforms = []
+        concatenated = None
         try:
             if len(output_files) > 1:
                 print("Concatenating audio files...")  # Debug log
 
                 # Load all audio files
-                waveforms = []
                 sample_rates = set()
                 for file in output_files:
raise ValueError(f"Sample rate mismatch: {sample_rates}") raise ValueError(f"Sample rate mismatch: {sample_rates}")
sample_rate = sample_rates.pop() sample_rate = sample_rates.pop()
gap_samples = int(1.0 * sample_rate) # 1 second gap
gap = torch.zeros(1, gap_samples) # Mono channel
# Concatenate waveforms with gaps # Start with the first audio segment
concatenated = waveforms[0] concatenated = waveforms[0]
for wav in waveforms[1:]:
# Add subsequent segments with appropriate silence gaps
for i, wav in enumerate(waveforms[1:], 1):
# Determine silence duration for this gap
if i-1 < len(silence_durations):
gap_duration = silence_durations[i-1]
else:
gap_duration = 1.0 # Default 1 second
gap_samples = int(gap_duration * sample_rate)
gap = torch.zeros(1, gap_samples) # Mono channel
print(f"Adding {gap_duration}s gap before segment {i+1}") # Debug log
concatenated = torch.cat([concatenated, gap, wav], dim=1) concatenated = torch.cat([concatenated, gap, wav], dim=1)
# Clean up gap tensor immediately
del gap
# Save concatenated file # Save concatenated file
concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav") concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav")
@@ -309,13 +395,57 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 error_msg = f"Error creating zip file: {str(e)}"
                 print(error_msg)  # Debug log
                 return "\n".join(summary + [error_msg]), None
+        finally:
+            # Clean up concatenation tensors
+            try:
+                if waveforms:
+                    for waveform in waveforms:
+                        del waveform
+                    waveforms.clear()
+                if concatenated is not None:
+                    del concatenated
+                # Force garbage collection
+                import gc
+                gc.collect()
+                print("Concatenation tensors cleaned up")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during concatenation cleanup: {cleanup_error}")  # Debug log
 
     except Exception as e:
         error_msg = f"Unexpected error: {str(e)}"
         print(error_msg)  # Debug log
         import traceback
         traceback.print_exc()  # Print full traceback
         return error_msg, None
+    finally:
+        # Clean up model to free memory
+        if 'model' in locals() and model is not None:
+            print("Cleaning up TTS model to free memory...")  # Debug log
+            try:
+                # Move model to CPU and delete reference
+                if hasattr(model, 'cpu'):
+                    model.cpu()
+                del model
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                # Clear GPU cache - both CUDA and MPS
+                try:
+                    import torch
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    # Clear MPS cache for Apple Silicon
+                    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                        torch.mps.empty_cache()
+                except:
+                    pass
+                print("Model cleanup completed")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during model cleanup: {cleanup_error}")  # Debug log
 
 with gr.Blocks() as demo:
     gr.Markdown("# Chatterbox TTS Generator")
@@ -461,13 +591,23 @@ Speaker2: "I'm working on a new project."''',
     add_speaker_btn.click(
         fn=add_speaker,
         inputs=[new_speaker_name, new_speaker_audio, speakers_state],
-        outputs=[config_status, speakers_dropdown]
+        outputs=[config_status, speakers_state]
     )
 
     remove_speaker_btn.click(
         fn=remove_speaker,
         inputs=[speakers_dropdown, speakers_state],
-        outputs=[config_status, speakers_dropdown]
+        outputs=[config_status, speakers_state]
     )
 
+    # Update the speakers dropdown when speakers_state changes
+    def update_dropdown(speakers_dict):
+        return gr.Dropdown(choices=list(speakers_dict.keys()) if speakers_dict else [])
+
+    speakers_state.change(
+        fn=update_dropdown,
+        inputs=[speakers_state],
+        outputs=[speakers_dropdown]
+    )
+
     # Update the speakers dropdown when the tab is selected

speakers.yaml

@@ -1,7 +1,11 @@
-Tara: Tara.mp3
-Zac: Zac.mp3
-Leah: Leah.mp3
-Leo: Leo.mp3
 Adam: Adam.mp3
 Alice: Alice.mp3
-Lewis: Lewis.mp3
+David: speaker_samples/david.mp3
+Debbie: speaker_samples/debbie.mp3
+Denise: speaker_samples/denise.mp3
+Leah: Leah.mp3
+Leo: Leo.mp3
+Lewis: Lewis.mp3
+Maria: speaker_samples/maria.mp3
+Tara: Tara.mp3
+Zac: Zac.mp3