Major update: Enhanced memory management, configurable silence gaps, and file organization
- Added aggressive memory cleanup for both single and dialog generation
- Implemented configurable silence gaps between dialog lines
- Organized output files into single_output/ and dialog_output/ directories
- Fixed torch import scoping issues
- Updated README with comprehensive documentation
parent 869914e8a0
commit 769daab4c7
README.md (new file)
@@ -0,0 +1,73 @@
# Chatterbox TTS Gradio App

This Gradio application provides a user interface for text-to-speech generation using the Chatterbox TTS model. It supports both single-utterance generation and multi-speaker dialog generation with configurable silence gaps.

## Features

- **Single Utterance Generation**: Generate speech from text using a selected speaker
- **Dialog Generation**: Create multi-speaker conversations with configurable silence gaps
- **Speaker Management**: Add or remove speakers with custom audio samples
- **Memory Optimization**: Automatic model cleanup after generation
- **Output Organization**: Files saved in `single_output/` and `dialog_output/` directories

## Getting Started

1. Clone the repository:
   ```bash
   git clone https://github.com/your-username/chatterbox-test.git
   ```
2. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```
3. Prepare speaker samples:
   - Create a `speaker_samples/` directory
   - Add audio samples (WAV format) for each speaker
   - Update `speakers.yaml` with speaker names and file paths (see the example after this list)
4. Run the app:
   ```bash
   python gradio_app.py
   ```
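
Each `speakers.yaml` entry maps a speaker name to an audio sample path; the entries below mirror the repository's own `speakers.yaml` (paths may be bare filenames or relative paths under the project root):

```yaml
Adam: Adam.mp3
Alice: Alice.mp3
Maria: speaker_samples/maria.mp3
```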

## Usage

### Single Utterance Tab

- Select a speaker from the dropdown
- Enter the text to synthesize
- Adjust generation parameters as needed
- Click "Generate Speech"

### Dialog Generation Tab

1. Add speakers using the speaker configuration section
2. Enter dialog in the format below, where a `Silence:` line inserts a pause of the given length in seconds:
   ```
   Speaker1: "Hello, how are you?"
   Speaker2: "I'm doing well!"
   Silence: 0.5
   Speaker1: "What are your plans for today?"
   ```
3. Set the output base name
4. Click "Generate Dialog"

## File Organization

- Generated single utterances are saved to `single_output/`
- Dialog generation files are saved to `dialog_output/`
- Concatenated dialog files have a `_concatenated.wav` suffix
- All files are zipped together for download (see the layout sketch after this list)
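
After a run, the output tree looks roughly like the following sketch (the `mydialog` base name and the timestamp are illustrative placeholders; per-line files are numbered by the app):

```
single_output/
    output_1718000000.wav        # timestamped single utterance
dialog_output/
    001-mydialog.wav             # one numbered file per dialog line
    002-mydialog.wav
    mydialog_concatenated.wav    # all lines joined with silence gaps
```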

## Memory Management

The app automatically:

- Cleans up the TTS model after each generation
- Frees GPU memory (on CUDA and MPS devices)
- Deletes intermediate tensors to minimize the memory footprint
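
A minimal sketch of that cleanup pattern, assuming a loaded model object `tts` (the app inlines this logic rather than using a helper like this):

```python
import gc
import torch

def cleanup_model(tts):
    """Release a TTS model's memory after generation (illustrative helper)."""
    if hasattr(tts, 'cpu'):
        tts.cpu()   # move weights off the accelerator first
    del tts         # drop the reference so gc can reclaim it
    gc.collect()    # force garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached CUDA memory
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        torch.mps.empty_cache()   # release cached MPS memory (Apple Silicon)
```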

## Troubleshooting

- **"Skipping unknown speaker"**: Add the speaker first using the speaker configuration section
- **"Sample file not found"**: Verify that the audio file exists in `speaker_samples/`
- **Memory issues**: Try enabling "Re-initialize model each line" for long dialogs

gradio_app.py (176 lines changed)
@@ -57,13 +57,19 @@ def split_text_at_sentence_boundaries(text, max_length=300):
     return chunks

 def parse_dialog_line(line):
-    """Parse a dialog line in the format: Name: "Text"""
-    pattern = r'^([^:]+):\s*"([^"]+)"$'
-    match = re.match(pattern, line.strip())
+    """Parse a dialog line to extract speaker and text."""
+    # Handle silence lines
+    silence_match = re.match(r'^Silence:\s*([0-9]*\.?[0-9]+)$', line.strip(), re.IGNORECASE)
+    if silence_match:
+        duration = float(silence_match.group(1))
+        return "SILENCE", duration
+
+    # Handle regular dialog lines
+    match = re.match(r'^([^:]+):\s*["\'](.+)["\']$', line.strip())
     if match:
         speaker = match.group(1).strip()
         text = match.group(2).strip()
-        return (speaker, text)
+        return speaker, text
     return None

 def save_speakers_config(speakers_dict: Dict) -> str:
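
The reworked parser returns ("SILENCE", duration) for pause lines and a (speaker, text) tuple otherwise. A quick, hypothetical illustration of the expected behavior (not part of the diff):

```python
parse_dialog_line('Tara: "Hello there!"')  # -> ('Tara', 'Hello there!')
parse_dialog_line('Silence: 0.5')          # -> ('SILENCE', 0.5)
parse_dialog_line('no quoted text here')   # -> None
```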
@@ -123,6 +129,10 @@ def update_speakers_dropdown():
     return gr.Dropdown.update(choices=list(speakers.keys()))

 def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens):
+    import torch  # Ensure torch is available in function scope
+    tts = None
+    wav = None
     try:
         # Get sample path from selection or upload
         sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
@@ -132,7 +142,8 @@ def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight
         # Load model (cached automatically by Gradio)
         tts = ChatterboxTTS.from_pretrained(device="mps")

-        # Generate audio with advanced controls
+        # Generate audio with advanced controls - disable gradients for inference
+        with torch.no_grad():
             gen_kwargs = dict(
                 text=text,
                 audio_prompt_path=sample_path,
@@ -143,14 +154,55 @@ def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight
             # max_new_tokens is not supported by the current TTS library, so we ignore it here
             wav = tts.generate(**gen_kwargs)

-        # Save with timestamp
-        output_path = f"output_{int(time.time())}.wav"
+        # Create output directory for single utterances
+        output_dir = "single_output"
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save with timestamp in the subdirectory
+        output_path = os.path.join(output_dir, f"output_{int(time.time())}.wav")
         ta.save(output_path, wav, tts.sr)

         return output_path, output_path
+
+    finally:
+        # Clean up audio tensor first
+        if wav is not None:
+            try:
+                del wav
+            except:
+                pass
+
+        # Clean up model to free memory
+        if tts is not None:
+            print("Cleaning up TTS model to free memory...")  # Debug log
+            try:
+                # Move model to CPU and delete reference
+                if hasattr(tts, 'cpu'):
+                    tts.cpu()
+                del tts
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                # Clear GPU cache - both CUDA and MPS
+                try:
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    # Clear MPS cache for Apple Silicon
+                    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                        torch.mps.empty_cache()
+                except:
+                    pass
+
+                print("Model cleanup completed")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during model cleanup: {cleanup_error}")  # Debug log

 def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()):
     """Process dialog text and generate audio files."""
+    import torch  # Ensure torch is available in function scope
+    model = None
     try:
         print("Starting dialog processing...")  # Debug log
         print(f"Speaker samples: {speaker_samples}")  # Debug log
@@ -167,7 +219,6 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         print(f"Processing {len(dialog_lines)} dialog lines")  # Debug log

         # Initialize model only once if not reinitializing per line
-        model = None
         if not reinit_each_line:
             progress(0.1, desc="Loading TTS model...")
             try:
@@ -183,6 +234,7 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         # Process each dialog line
         file_counter = 1
         output_files = []
+        silence_durations = []  # Track silence durations for concatenation
         summary = []

         for i, line in enumerate(dialog_lines):
@@ -195,13 +247,25 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 print(f"Skipping line (invalid format): {line}")  # Debug log
                 continue

-            speaker, text = parsed
-            print(f"Found speaker: {speaker}, text: {text[:50]}...")  # Debug log
+            speaker, text_or_duration = parsed
+            print(f"Found speaker: {speaker}, content: {text_or_duration}")  # Debug log
+
+            # Handle silence entries
+            if speaker == "SILENCE":
+                duration = text_or_duration
+                print(f"Adding silence duration: {duration} seconds")  # Debug log
+                silence_durations.append(duration)
+                summary.append(f"Silence: {duration} seconds")
+                continue
+
+            # Handle regular dialog
+            text = text_or_duration
             if speaker not in speaker_samples:
                 msg = f"Skipping unknown speaker: {speaker}"
                 print(msg)  # Debug log
                 summary.append(msg)
+                # Add default silence duration if we skip a speaker
+                silence_durations.append(1.0)
                 continue

             sample_path = speaker_samples[speaker]
@@ -209,6 +273,8 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 msg = f"Audio sample not found for speaker '{speaker}': {sample_path}"
                 print(msg)  # Debug log
                 summary.append(msg)
+                # Add default silence duration if we skip due to missing sample
+                silence_durations.append(1.0)
                 continue

             if reinit_each_line or model is None:
@@ -222,8 +288,11 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                     output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                     print(f"Generating audio for chunk: {chunk[:50]}...")  # Debug log
                     try:
+                        with torch.no_grad():
                             wav = model.generate(chunk, audio_prompt_path=sample_path)
                         ta.save(output_file, wav, model.sr)
+                        # Clean up wav tensor immediately
+                        del wav
                         output_files.append(output_file)
                         summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...")
                         file_counter += 1
@@ -237,8 +306,11 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                 print(f"Generating audio: {text[:50]}...")  # Debug log
                 try:
+                    with torch.no_grad():
                         wav = model.generate(text, audio_prompt_path=sample_path)
                     ta.save(output_file, wav, model.sr)
+                    # Clean up wav tensor immediately
+                    del wav
                     output_files.append(output_file)
                     summary.append(f"{output_file}: {speaker} - {text[:50]}...")
                     file_counter += 1
@@ -258,13 +330,14 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         if not output_files:
             return "Error: No audio files were generated. Check speaker names and audio samples.", None

-        # Concatenate all audio files with 1-second gaps
+        # Concatenate all audio files with configurable gaps
         concatenated_file = None
+        waveforms = []
+        concatenated = None
         try:
             if len(output_files) > 1:
                 print("Concatenating audio files...")  # Debug log
                 # Load all audio files
-                waveforms = []
                 sample_rates = set()

                 for file in output_files:
@@ -276,14 +349,27 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                     raise ValueError(f"Sample rate mismatch: {sample_rates}")

                 sample_rate = sample_rates.pop()
-                gap_samples = int(1.0 * sample_rate)  # 1 second gap

-                # Concatenate waveforms with gaps
-                concatenated = waveforms[0]
-                for wav in waveforms[1:]:
+                # Start with the first audio segment
+                concatenated = waveforms[0]
+
+                # Add subsequent segments with appropriate silence gaps
+                for i, wav in enumerate(waveforms[1:], 1):
+                    # Determine silence duration for this gap
+                    if i-1 < len(silence_durations):
+                        gap_duration = silence_durations[i-1]
+                    else:
+                        gap_duration = 1.0  # Default 1 second
+
+                    gap_samples = int(gap_duration * sample_rate)
+                    gap = torch.zeros(1, gap_samples)  # Mono channel
+
+                    print(f"Adding {gap_duration}s gap before segment {i+1}")  # Debug log
                     concatenated = torch.cat([concatenated, gap, wav], dim=1)

+                    # Clean up gap tensor immediately
+                    del gap
+
                 # Save concatenated file
                 concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav")
                 ta.save(concatenated_path, concatenated, sample_rate)
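
Each `Silence:` entry maps onto one gap between consecutive segments. A worked example of the gap computation above (the 24 kHz rate is illustrative; the real value is read from the generated WAV files at runtime):

```python
sample_rate = 24000                            # illustrative; taken from the files at runtime
gap_duration = 0.5                             # from a `Silence: 0.5` dialog line
gap_samples = int(gap_duration * sample_rate)  # 12000 zero samples = 0.5 s of silence
# torch.zeros(1, gap_samples) then yields the mono gap tensor that gets concatenated in
```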
@@ -309,6 +395,21 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
             error_msg = f"Error creating zip file: {str(e)}"
             print(error_msg)  # Debug log
             return "\n".join(summary + [error_msg]), None
+        finally:
+            # Clean up concatenation tensors
+            try:
+                if waveforms:
+                    for waveform in waveforms:
+                        del waveform
+                    waveforms.clear()
+                if concatenated is not None:
+                    del concatenated
+                # Force garbage collection
+                import gc
+                gc.collect()
+                print("Concatenation tensors cleaned up")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during concatenation cleanup: {cleanup_error}")  # Debug log

     except Exception as e:
         error_msg = f"Unexpected error: {str(e)}"
@@ -317,6 +418,35 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         traceback.print_exc()  # Print full traceback
         return error_msg, None

+    finally:
+        # Clean up model to free memory
+        if 'model' in locals() and model is not None:
+            print("Cleaning up TTS model to free memory...")  # Debug log
+            try:
+                # Move model to CPU and delete reference
+                if hasattr(model, 'cpu'):
+                    model.cpu()
+                del model
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                # Clear GPU cache - both CUDA and MPS
+                try:
+                    import torch
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    # Clear MPS cache for Apple Silicon
+                    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                        torch.mps.empty_cache()
+                except:
+                    pass
+
+                print("Model cleanup completed")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during model cleanup: {cleanup_error}")  # Debug log

 with gr.Blocks() as demo:
     gr.Markdown("# Chatterbox TTS Generator")
@@ -461,13 +591,23 @@ Speaker2: "I'm working on a new project."''',
     add_speaker_btn.click(
         fn=add_speaker,
         inputs=[new_speaker_name, new_speaker_audio, speakers_state],
-        outputs=[config_status, speakers_dropdown]
+        outputs=[config_status, speakers_state]
     )

     remove_speaker_btn.click(
         fn=remove_speaker,
         inputs=[speakers_dropdown, speakers_state],
-        outputs=[config_status, speakers_dropdown]
+        outputs=[config_status, speakers_state]
     )

+    # Update the speakers dropdown when speakers_state changes
+    def update_dropdown(speakers_dict):
+        return gr.Dropdown(choices=list(speakers_dict.keys()) if speakers_dict else [])
+
+    speakers_state.change(
+        fn=update_dropdown,
+        inputs=[speakers_state],
+        outputs=[speakers_dropdown]
+    )
+
     # Update the speakers dropdown when the tab is selected
speakers.yaml
@@ -1,7 +1,11 @@
-Tara: Tara.mp3
-Zac: Zac.mp3
-Leah: Leah.mp3
-Leo: Leo.mp3
 Adam: Adam.mp3
 Alice: Alice.mp3
+David: speaker_samples/david.mp3
+Debbie: speaker_samples/debbie.mp3
+Denise: speaker_samples/denise.mp3
+Leah: Leah.mp3
+Leo: Leo.mp3
 Lewis: Lewis.mp3
+Maria: speaker_samples/maria.mp3
+Tara: Tara.mp3
+Zac: Zac.mp3