diff --git a/README.md b/README.md
new file mode 100644
index 0000000..6860fed
--- /dev/null
+++ b/README.md
@@ -0,0 +1,73 @@
+# Chatterbox TTS Gradio App
+
+This Gradio application provides a user interface for text-to-speech generation with the Chatterbox TTS model. It supports both single-utterance generation and multi-speaker dialog generation with configurable silence gaps.
+
+## Features
+
+- **Single Utterance Generation**: Generate speech from text using a selected speaker
+- **Dialog Generation**: Create multi-speaker conversations with configurable silence gaps
+- **Speaker Management**: Add and remove speakers with custom audio samples
+- **Memory Optimization**: Automatic model cleanup after generation
+- **Output Organization**: Files saved in `single_output/` and `dialog_output/` directories
+
+## Getting Started
+
+1. Clone the repository:
+   ```bash
+   git clone https://github.com/your-username/chatterbox-test.git
+   ```
+
+2. Install dependencies:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+3. Prepare speaker samples:
+   - Create a `speaker_samples/` directory
+   - Add an audio sample (WAV format) for each speaker
+   - Update `speakers.yaml` with speaker names and file paths
+
+4. Run the app:
+   ```bash
+   python gradio_app.py
+   ```
+
+## Usage
+
+### Single Utterance Tab
+- Select a speaker from the dropdown
+- Enter the text to synthesize
+- Adjust generation parameters as needed
+- Click "Generate Speech"
+
+### Dialog Generation Tab
+1. Add speakers using the speaker configuration section
+2. Enter the dialog in the format:
+   ```
+   Speaker1: "Hello, how are you?"
+   Speaker2: "I'm doing well!"
+   Silence: 0.5
+   Speaker1: "What are your plans for today?"
+   ```
+3. Set the output base name
+4. Click "Generate Dialog"
+
+## File Organization
+
+- Generated single utterances are saved to `single_output/`
+- Dialog generation files are saved to `dialog_output/`
+- Concatenated dialog files have a `_concatenated.wav` suffix
+- All files are zipped together for download
+
+## Memory Management
+
+The app automatically:
+- Cleans up the TTS model after each generation
+- Frees GPU memory (for CUDA/MPS devices)
+- Deletes intermediate tensors to minimize memory footprint
+
+## Troubleshooting
+
+- **"Skipping unknown speaker"**: Add the speaker first using the speaker configuration
+- **"Sample file not found"**: Verify that the audio file exists in `speaker_samples/`
+- **Memory issues**: Try enabling "Re-initialize model each line" for long dialogs
diff --git a/gradio_app.py b/gradio_app.py
index 95c11e1..9053428 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -57,13 +57,19 @@ def split_text_at_sentence_boundaries(text, max_length=300):
     return chunks
 
 def parse_dialog_line(line):
-    """Parse a dialog line in the format: Name: "Text"""
-    pattern = r'^([^:]+):\s*"([^"]+)"$'
-    match = re.match(pattern, line.strip())
+    """Parse a dialog line to extract speaker and text."""
+    # Handle silence lines
+    silence_match = re.match(r'^Silence:\s*([0-9]*\.?[0-9]+)$', line.strip(), re.IGNORECASE)
+    if silence_match:
+        duration = float(silence_match.group(1))
+        return "SILENCE", duration
+
+    # Handle regular dialog lines
+    match = re.match(r'^([^:]+):\s*["\'](.+)["\']$', line.strip())
     if match:
         speaker = match.group(1).strip()
         text = match.group(2).strip()
-        return (speaker, text)
+        return speaker, text
    return None
 
 def save_speakers_config(speakers_dict: Dict) -> str:
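For reference, a standalone sketch of how the new parser behaves; the regexes are copied from the hunk above, and the sample lines are hypothetical:

```python
import re

def parse_dialog_line(line):
    """Mirror of the patched parser: silence directives first, then dialog."""
    silence_match = re.match(r'^Silence:\s*([0-9]*\.?[0-9]+)$', line.strip(), re.IGNORECASE)
    if silence_match:
        return "SILENCE", float(silence_match.group(1))
    match = re.match(r'^([^:]+):\s*["\'](.+)["\']$', line.strip())
    if match:
        return match.group(1).strip(), match.group(2).strip()
    return None

print(parse_dialog_line('Alice: "Hello there!"'))  # ('Alice', 'Hello there!')
print(parse_dialog_line('Silence: 0.5'))           # ('SILENCE', 0.5)
print(parse_dialog_line('just a stray line'))      # None
```

Because silence directives are matched first, a bare `Silence: 0.5` line is recognized even though it carries no quoted text.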
@@ -123,34 +129,80 @@ def update_speakers_dropdown():
     return gr.Dropdown.update(choices=list(speakers.keys()))
 
 def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens):
-    # Get sample path from selection or upload
-    sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
+    import torch  # Ensure torch is available in function scope
+    tts = None
+    wav = None
+    try:
+        # Get sample path from selection or upload
+        sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
+
+        if not os.path.exists(sample_path):
+            raise gr.Error("Sample file not found!")
+
+        # Load model (cached automatically by Gradio)
+        tts = ChatterboxTTS.from_pretrained(device="mps")
+
+        # Generate audio with advanced controls - disable gradients for inference
+        with torch.no_grad():
+            gen_kwargs = dict(
+                text=text,
+                audio_prompt_path=sample_path,
+                exaggeration=exaggeration,
+                cfg_weight=cfg_weight,
+                temperature=temperature
+            )
+            # max_new_tokens is not supported by the current TTS library, so we ignore it here
+            wav = tts.generate(**gen_kwargs)
+
+        # Create output directory for single utterances
+        output_dir = "single_output"
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save with timestamp in the subdirectory
+        output_path = os.path.join(output_dir, f"output_{int(time.time())}.wav")
+        ta.save(output_path, wav, tts.sr)
+
+        return output_path, output_path
 
-    if not os.path.exists(sample_path):
-        raise gr.Error("Sample file not found!")
-
-    # Load model (cached automatically by Gradio)
-    tts = ChatterboxTTS.from_pretrained(device="mps")
-
-    # Generate audio with advanced controls
-    gen_kwargs = dict(
-        text=text,
-        audio_prompt_path=sample_path,
-        exaggeration=exaggeration,
-        cfg_weight=cfg_weight,
-        temperature=temperature
-    )
-    # max_new_tokens is not supported by the current TTS library, so we ignore it here
-    wav = tts.generate(**gen_kwargs)
-
-    # Save with timestamp
-    output_path = f"output_{int(time.time())}.wav"
-    ta.save(output_path, wav, tts.sr)
-
-    return output_path, output_path
+    finally:
+        # Clean up audio tensor first
+        if wav is not None:
+            try:
+                del wav
+            except:
+                pass
+
+        # Clean up model to free memory
+        if tts is not None:
+            print("Cleaning up TTS model to free memory...")  # Debug log
+            try:
+                # Move model to CPU and delete reference
+                if hasattr(tts, 'cpu'):
+                    tts.cpu()
+                del tts
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                # Clear GPU cache - both CUDA and MPS
+                try:
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    # Clear MPS cache for Apple Silicon
+                    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                        torch.mps.empty_cache()
+                except:
+                    pass
+
+                print("Model cleanup completed")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during model cleanup: {cleanup_error}")  # Debug log
 
 def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()):
     """Process dialog text and generate audio files."""
+    import torch  # Ensure torch is available in function scope
+    model = None
     try:
         print("Starting dialog processing...")  # Debug log
         print(f"Speaker samples: {speaker_samples}")  # Debug log
@@ -167,7 +219,6 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         print(f"Processing {len(dialog_lines)} dialog lines")  # Debug log
 
         # Initialize model only once if not reinitializing per line
-        model = None
         if not reinit_each_line:
             progress(0.1, desc="Loading TTS model...")
             try:
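The same teardown sequence now appears in both `generate_audio` and `process_dialog`. A minimal sketch of that pattern as a reusable helper, assuming only that the model object exposes a `.cpu()` method (as `ChatterboxTTS` does here); the helper name is hypothetical:

```python
import gc
import torch

def release_model(model):
    """Move a model off the accelerator and release cached device memory."""
    if model is None:
        return
    if hasattr(model, "cpu"):
        model.cpu()  # move weights back to system RAM first
    del model        # drop the local reference to the weights
    gc.collect()     # collect the now-unreferenced tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached CUDA blocks
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        torch.mps.empty_cache()   # same for Apple Silicon
```

`del` only removes a name; it is the garbage collection plus `empty_cache()` calls that actually hand device memory back once nothing references the weights.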
@@ -183,6 +234,7 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         # Process each dialog line
         file_counter = 1
         output_files = []
+        silence_durations = []  # Track silence durations for concatenation
         summary = []
 
         for i, line in enumerate(dialog_lines):
@@ -195,13 +247,25 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 print(f"Skipping line (invalid format): {line}")  # Debug log
                 continue
 
-            speaker, text = parsed
-            print(f"Found speaker: {speaker}, text: {text[:50]}...")  # Debug log
+            speaker, text_or_duration = parsed
+            print(f"Found speaker: {speaker}, content: {text_or_duration}")  # Debug log
 
+            # Handle silence entries
+            if speaker == "SILENCE":
+                duration = text_or_duration
+                print(f"Adding silence duration: {duration} seconds")  # Debug log
+                silence_durations.append(duration)
+                summary.append(f"Silence: {duration} seconds")
+                continue
+
+            # Handle regular dialog
+            text = text_or_duration
             if speaker not in speaker_samples:
                 msg = f"Skipping unknown speaker: {speaker}"
                 print(msg)  # Debug log
                 summary.append(msg)
+                # Add default silence duration if we skip a speaker
+                silence_durations.append(1.0)
                 continue
 
             sample_path = speaker_samples[speaker]
@@ -209,6 +273,8 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 msg = f"Audio sample not found for speaker '{speaker}': {sample_path}"
                 print(msg)  # Debug log
                 summary.append(msg)
+                # Add default silence duration if we skip due to missing sample
+                silence_durations.append(1.0)
                 continue
 
             if reinit_each_line or model is None:
@@ -222,8 +288,11 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                     output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                     print(f"Generating audio for chunk: {chunk[:50]}...")  # Debug log
                     try:
-                        wav = model.generate(chunk, audio_prompt_path=sample_path)
+                        with torch.no_grad():
+                            wav = model.generate(chunk, audio_prompt_path=sample_path)
                         ta.save(output_file, wav, model.sr)
+                        # Clean up wav tensor immediately
+                        del wav
                         output_files.append(output_file)
                         summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...")
                         file_counter += 1
@@ -237,8 +306,11 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                 print(f"Generating audio: {text[:50]}...")  # Debug log
                 try:
-                    wav = model.generate(text, audio_prompt_path=sample_path)
+                    with torch.no_grad():
+                        wav = model.generate(text, audio_prompt_path=sample_path)
                     ta.save(output_file, wav, model.sr)
+                    # Clean up wav tensor immediately
+                    del wav
                     output_files.append(output_file)
                     summary.append(f"{output_file}: {speaker} - {text[:50]}...")
                     file_counter += 1
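`silence_durations` is consumed positionally during concatenation: the gap before segment `i+1` reads `silence_durations[i-1]` and falls back to 1.0 s when the list runs out, which is why a default entry is appended whenever a line is skipped. A small sketch of that lookup with hypothetical file names:

```python
segments = ["001-demo.wav", "002-demo.wav", "003-demo.wav"]
silence_durations = [0.5]  # one explicit 'Silence: 0.5' line was parsed

# Same fallback logic the concatenation loop below performs
for i in range(1, len(segments)):
    gap = silence_durations[i - 1] if i - 1 < len(silence_durations) else 1.0
    print(f"{segments[i - 1]} --[{gap}s]--> {segments[i]}")
# 001-demo.wav --[0.5s]--> 002-demo.wav
# 002-demo.wav --[1.0s]--> 003-demo.wav
```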
@@ -258,13 +330,14 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
         if not output_files:
             return "Error: No audio files were generated. Check speaker names and audio samples.", None
 
-        # Concatenate all audio files with 1-second gaps
+        # Concatenate all audio files with configurable gaps
         concatenated_file = None
+        waveforms = []
+        concatenated = None
         try:
             if len(output_files) > 1:
                 print("Concatenating audio files...")  # Debug log
                 # Load all audio files
-                waveforms = []
                 sample_rates = set()
 
                 for file in output_files:
@@ -276,13 +349,26 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                     raise ValueError(f"Sample rate mismatch: {sample_rates}")
 
                 sample_rate = sample_rates.pop()
-                gap_samples = int(1.0 * sample_rate)  # 1 second gap
-                gap = torch.zeros(1, gap_samples)  # Mono channel
 
-                # Concatenate waveforms with gaps
+                # Start with the first audio segment
                 concatenated = waveforms[0]
-                for wav in waveforms[1:]:
+
+                # Add subsequent segments with appropriate silence gaps
+                for i, wav in enumerate(waveforms[1:], 1):
+                    # Determine silence duration for this gap
+                    if i-1 < len(silence_durations):
+                        gap_duration = silence_durations[i-1]
+                    else:
+                        gap_duration = 1.0  # Default 1 second
+
+                    gap_samples = int(gap_duration * sample_rate)
+                    gap = torch.zeros(1, gap_samples)  # Mono channel
+
+                    print(f"Adding {gap_duration}s gap before segment {i+1}")  # Debug log
                     concatenated = torch.cat([concatenated, gap, wav], dim=1)
+
+                    # Clean up gap tensor immediately
+                    del gap
 
                 # Save concatenated file
                 concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav")
@@ -309,13 +395,57 @@ def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line,
                 error_msg = f"Error creating zip file: {str(e)}"
                 print(error_msg)  # Debug log
                 return "\n".join(summary + [error_msg]), None
-
+        finally:
+            # Clean up concatenation tensors
+            try:
+                if waveforms:
+                    for waveform in waveforms:
+                        del waveform
+                    waveforms.clear()
+                if concatenated is not None:
+                    del concatenated
+                # Force garbage collection
+                import gc
+                gc.collect()
+                print("Concatenation tensors cleaned up")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during concatenation cleanup: {cleanup_error}")  # Debug log
+
     except Exception as e:
         error_msg = f"Unexpected error: {str(e)}"
         print(error_msg)  # Debug log
         import traceback
         traceback.print_exc()  # Print full traceback
         return error_msg, None
+
+    finally:
+        # Clean up model to free memory
+        if 'model' in locals() and model is not None:
+            print("Cleaning up TTS model to free memory...")  # Debug log
+            try:
+                # Move model to CPU and delete reference
+                if hasattr(model, 'cpu'):
+                    model.cpu()
+                del model
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                # Clear GPU cache - both CUDA and MPS
+                try:
+                    import torch
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    # Clear MPS cache for Apple Silicon
+                    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                        torch.mps.empty_cache()
+                except:
+                    pass
+
+                print("Model cleanup completed")  # Debug log
+            except Exception as cleanup_error:
+                print(f"Error during model cleanup: {cleanup_error}")  # Debug log
 
 with gr.Blocks() as demo:
     gr.Markdown("# Chatterbox TTS Generator")
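The concatenation step itself is plain tensor surgery: silence is a zero tensor whose length is the gap duration times the sample rate, joined along the time axis. A runnable sketch with made-up tensors and an assumed 24 kHz rate (at runtime the app uses the loaded waveforms and `model.sr`):

```python
import torch

sample_rate = 24_000                  # assumed for the example
a = torch.randn(1, sample_rate)       # stand-in for a 1 s mono segment
b = torch.randn(1, sample_rate)       # stand-in for the next segment

gap_duration = 0.5
gap = torch.zeros(1, int(gap_duration * sample_rate))  # 0.5 s of silence

joined = torch.cat([a, gap, b], dim=1)  # concatenate along the time axis
print(joined.shape)  # torch.Size([1, 60000])
```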
@@ -461,13 +591,23 @@ Speaker2: "I'm working on a new project."''',
         add_speaker_btn.click(
             fn=add_speaker,
             inputs=[new_speaker_name, new_speaker_audio, speakers_state],
-            outputs=[config_status, speakers_dropdown]
+            outputs=[config_status, speakers_state]
         )
 
         remove_speaker_btn.click(
             fn=remove_speaker,
             inputs=[speakers_dropdown, speakers_state],
-            outputs=[config_status, speakers_dropdown]
+            outputs=[config_status, speakers_state]
+        )
+
+        # Update the speakers dropdown when speakers_state changes
+        def update_dropdown(speakers_dict):
+            return gr.Dropdown(choices=list(speakers_dict.keys()) if speakers_dict else [])
+
+        speakers_state.change(
+            fn=update_dropdown,
+            inputs=[speakers_state],
+            outputs=[speakers_dropdown]
         )
 
         # Update the speakers dropdown when the tab is selected
diff --git a/speakers.yaml b/speakers.yaml
index a771b70..f278d55 100644
--- a/speakers.yaml
+++ b/speakers.yaml
@@ -1,7 +1,11 @@
-Tara: Tara.mp3
-Zac: Zac.mp3
-Leah: Leah.mp3
-Leo: Leo.mp3
 Adam: Adam.mp3
 Alice: Alice.mp3
-Lewis: Lewis.mp3
\ No newline at end of file
+David: speaker_samples/david.mp3
+Debbie: speaker_samples/debbie.mp3
+Denise: speaker_samples/denise.mp3
+Leah: Leah.mp3
+Leo: Leo.mp3
+Lewis: Lewis.mp3
+Maria: speaker_samples/maria.mp3
+Tara: Tara.mp3
+Zac: Zac.mp3
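With the reordering, `speakers.yaml` stays a flat name-to-path mapping, so a new entry only needs a display name and a sample path. A minimal sketch of reading it, assuming PyYAML (the loading code itself is not part of this diff):

```python
import yaml  # PyYAML

with open("speakers.yaml") as f:
    speakers = yaml.safe_load(f)  # plain dict: display name -> sample path

print(speakers["Maria"])  # speaker_samples/maria.mp3
```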