import gradio as gr import yaml import os import time import re import tempfile import shutil import numpy as np from chatterbox.tts import ChatterboxTTS import torchaudio as ta import torch from typing import Dict, Tuple, Optional, List # Load speaker options from YAML with error handling try: yaml_path = os.path.abspath("speakers.yaml") if not os.path.exists(yaml_path): raise FileNotFoundError(f"speakers.yaml not found at {yaml_path}") with open(yaml_path) as f: speakers = yaml.safe_load(f) if not speakers or not isinstance(speakers, dict): raise ValueError("speakers.yaml must contain a valid dictionary mapping") except Exception as e: raise SystemExit(f"Failed to load speakers.yaml: {str(e)}") def split_text_at_sentence_boundaries(text, max_length=300): """Split text at sentence boundaries, ensuring each chunk is <= max_length.""" sentence_pattern = r'[.!?](?:\s|$)' sentences = re.split(f'({sentence_pattern})', text) actual_sentences = [] current = "" for i in range(0, len(sentences), 2): if i+1 < len(sentences): current = sentences[i] + sentences[i+1] else: current = sentences[i] if current: actual_sentences.append(current) chunks = [] current_chunk = "" for sentence in actual_sentences: if len(current_chunk) + len(sentence) > max_length and current_chunk: chunks.append(current_chunk.strip()) current_chunk = sentence else: current_chunk += sentence if current_chunk: chunks.append(current_chunk.strip()) return chunks def parse_dialog_line(line): """Parse a dialog line to extract speaker and text.""" # Handle silence lines silence_match = re.match(r'^Silence:\s*([0-9]*\.?[0-9]+)$', line.strip(), re.IGNORECASE) if silence_match: duration = float(silence_match.group(1)) return "SILENCE", duration # Handle regular dialog lines match = re.match(r'^([^:]+):\s*["\'](.+)["\']$', line.strip()) if match: speaker = match.group(1).strip() text = match.group(2).strip() return speaker, text return None def save_speakers_config(speakers_dict: Dict) -> str: """Save speakers configuration to YAML file.""" try: with open("speakers.yaml", 'w') as f: yaml.dump(speakers_dict, f) return "Speakers configuration saved successfully!" except Exception as e: return f"Error saving configuration: {str(e)}" def add_speaker(speaker_name: str, audio_file: str, current_speakers: Dict) -> Tuple[str, Dict]: """Add a new speaker with audio sample.""" if not speaker_name or not audio_file: return "Please provide both speaker name and audio file", current_speakers if speaker_name in current_speakers: return f"Speaker '{speaker_name}' already exists!", current_speakers # Save the audio file speakers_dir = "speaker_samples" os.makedirs(speakers_dir, exist_ok=True) # Generate a unique filename ext = os.path.splitext(audio_file)[1] or '.wav' new_filename = f"{speaker_name.lower().replace(' ', '_')}{ext}" new_filepath = os.path.join(speakers_dir, new_filename) # Copy the uploaded file shutil.copy2(audio_file, new_filepath) # Update speakers dictionary updated_speakers = dict(current_speakers) updated_speakers[speaker_name] = new_filepath # Save to YAML save_speakers_config(updated_speakers) return f"Speaker '{speaker_name}' added successfully!", updated_speakers def remove_speaker(speaker_name: str, current_speakers: Dict) -> Tuple[str, Dict]: """Remove a speaker from the configuration.""" if not speaker_name or speaker_name not in current_speakers: return "Speaker not found!", current_speakers # Don't actually delete the audio file, just remove from config updated_speakers = dict(current_speakers) del updated_speakers[speaker_name] # Save to YAML save_speakers_config(updated_speakers) return f"Speaker '{speaker_name}' removed!", updated_speakers def update_speakers_dropdown(): """Update the speakers dropdown with current speakers.""" return gr.Dropdown.update(choices=list(speakers.keys())) def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens): import torch # Ensure torch is available in function scope tts = None wav = None try: # Get sample path from selection or upload sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample if not os.path.exists(sample_path): raise gr.Error("Sample file not found!") # Load model (cached automatically by Gradio) tts = ChatterboxTTS.from_pretrained(device="mps") # Generate audio with advanced controls - disable gradients for inference with torch.no_grad(): gen_kwargs = dict( text=text, audio_prompt_path=sample_path, exaggeration=exaggeration, cfg_weight=cfg_weight, temperature=temperature ) # max_new_tokens is not supported by the current TTS library, so we ignore it here wav = tts.generate(**gen_kwargs) # Create output directory for single utterances output_dir = "single_output" os.makedirs(output_dir, exist_ok=True) # Save with timestamp in the subdirectory output_path = os.path.join(output_dir, f"output_{int(time.time())}.wav") ta.save(output_path, wav, tts.sr) return output_path, output_path finally: # Clean up audio tensor first if wav is not None: try: del wav except: pass # Clean up model to free memory if tts is not None: print("Cleaning up TTS model to free memory...") # Debug log try: # Move model to CPU and delete reference if hasattr(tts, 'cpu'): tts.cpu() del tts # Force garbage collection import gc gc.collect() # Clear GPU cache - both CUDA and MPS try: if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear MPS cache for Apple Silicon if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.empty_cache() except: pass print("Model cleanup completed") # Debug log except Exception as cleanup_error: print(f"Error during model cleanup: {cleanup_error}") # Debug log def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()): """Process dialog text and generate audio files.""" import torch # Ensure torch is available in function scope model = None try: print("Starting dialog processing...") # Debug log print(f"Speaker samples: {speaker_samples}") # Debug log if not dialog_text or not dialog_text.strip(): return "Error: No dialog text provided", None # Parse dialog lines dialog_lines = [line.strip() for line in dialog_text.split('\n') if line.strip()] if not dialog_lines: return "Error: No valid dialog lines found", None print(f"Processing {len(dialog_lines)} dialog lines") # Debug log # Initialize model only once if not reinitializing per line if not reinit_each_line: progress(0.1, desc="Loading TTS model...") try: model = ChatterboxTTS.from_pretrained(device="mps") print("TTS model loaded successfully") # Debug log except Exception as e: return f"Error loading TTS model: {str(e)}", None # Create output directory output_dir = "dialog_output" os.makedirs(output_dir, exist_ok=True) # Process each dialog line file_counter = 1 output_files = [] silence_durations = [] # Track silence durations for concatenation summary = [] for i, line in enumerate(dialog_lines): progress(i / len(dialog_lines), desc=f"Processing line {i+1}/{len(dialog_lines)}") print(f"Processing line {i+1}: {line}") # Debug log try: parsed = parse_dialog_line(line) if not parsed: print(f"Skipping line (invalid format): {line}") # Debug log continue speaker, text_or_duration = parsed print(f"Found speaker: {speaker}, content: {text_or_duration}") # Debug log # Handle silence entries if speaker == "SILENCE": duration = text_or_duration print(f"Adding silence duration: {duration} seconds") # Debug log silence_durations.append(duration) summary.append(f"Silence: {duration} seconds") continue # Handle regular dialog text = text_or_duration if speaker not in speaker_samples: msg = f"Skipping unknown speaker: {speaker}" print(msg) # Debug log summary.append(msg) # Add default silence duration if we skip a speaker silence_durations.append(1.0) continue sample_path = speaker_samples[speaker] if not os.path.exists(sample_path): msg = f"Audio sample not found for speaker '{speaker}': {sample_path}" print(msg) # Debug log summary.append(msg) # Add default silence duration if we skip due to missing sample silence_durations.append(1.0) continue if reinit_each_line or model is None: print("Initializing new TTS model instance") # Debug log model = ChatterboxTTS.from_pretrained(device="mps") if len(text) > 300: chunks = split_text_at_sentence_boundaries(text) print(f"Splitting long text into {len(chunks)} chunks") # Debug log for chunk in chunks: output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav") print(f"Generating audio for chunk: {chunk[:50]}...") # Debug log try: with torch.no_grad(): wav = model.generate(chunk, audio_prompt_path=sample_path) ta.save(output_file, wav, model.sr) # Clean up wav tensor immediately del wav output_files.append(output_file) summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...") file_counter += 1 print(f"Generated audio: {output_file}") # Debug log except Exception as e: error_msg = f"Error generating audio for chunk: {str(e)}" print(error_msg) # Debug log summary.append(error_msg) continue else: output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav") print(f"Generating audio: {text[:50]}...") # Debug log try: with torch.no_grad(): wav = model.generate(text, audio_prompt_path=sample_path) ta.save(output_file, wav, model.sr) # Clean up wav tensor immediately del wav output_files.append(output_file) summary.append(f"{output_file}: {speaker} - {text[:50]}...") file_counter += 1 print(f"Generated audio: {output_file}") # Debug log except Exception as e: error_msg = f"Error generating audio: {str(e)}" print(error_msg) # Debug log summary.append(error_msg) continue except Exception as e: error_msg = f"Error processing line '{line}': {str(e)}" print(error_msg) # Debug log summary.append(error_msg) continue if not output_files: return "Error: No audio files were generated. Check speaker names and audio samples.", None # Concatenate all audio files with configurable gaps concatenated_file = None waveforms = [] concatenated = None try: if len(output_files) > 1: print("Concatenating audio files...") # Debug log # Load all audio files sample_rates = set() for file in output_files: waveform, sample_rate = ta.load(file) waveforms.append(waveform) sample_rates.add(sample_rate) if len(sample_rates) != 1: raise ValueError(f"Sample rate mismatch: {sample_rates}") sample_rate = sample_rates.pop() # Start with the first audio segment concatenated = waveforms[0] # Add subsequent segments with appropriate silence gaps for i, wav in enumerate(waveforms[1:], 1): # Determine silence duration for this gap if i-1 < len(silence_durations): gap_duration = silence_durations[i-1] else: gap_duration = 1.0 # Default 1 second gap_samples = int(gap_duration * sample_rate) gap = torch.zeros(1, gap_samples) # Mono channel print(f"Adding {gap_duration}s gap before segment {i+1}") # Debug log concatenated = torch.cat([concatenated, gap, wav], dim=1) # Clean up gap tensor immediately del gap # Save concatenated file concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav") ta.save(concatenated_path, concatenated, sample_rate) output_files.append(concatenated_path) summary.append(f"\nConcatenated file: {concatenated_path}") concatenated_file = concatenated_path print(f"Created concatenated file: {concatenated_path}") # Create a zip file of all outputs import zipfile zip_path = os.path.join(output_dir, f"{output_base}.zip") print(f"Creating zip file: {zip_path}") # Debug log with zipfile.ZipFile(zip_path, 'w') as zipf: for file in output_files: zipf.write(file, os.path.basename(file)) print(f"Zip file created successfully with {len(output_files)} files") # Debug log # Return both the zip and the concatenated file if it exists if concatenated_file: return "\n".join(summary), concatenated_file, zip_path return "\n".join(summary), None, zip_path except Exception as e: error_msg = f"Error creating zip file: {str(e)}" print(error_msg) # Debug log return "\n".join(summary + [error_msg]), None finally: # Clean up concatenation tensors try: if waveforms: for waveform in waveforms: del waveform waveforms.clear() if concatenated is not None: del concatenated # Force garbage collection import gc gc.collect() print("Concatenation tensors cleaned up") # Debug log except Exception as cleanup_error: print(f"Error during concatenation cleanup: {cleanup_error}") # Debug log except Exception as e: error_msg = f"Unexpected error: {str(e)}" print(error_msg) # Debug log import traceback traceback.print_exc() # Print full traceback return error_msg, None finally: # Clean up model to free memory if 'model' in locals() and model is not None: print("Cleaning up TTS model to free memory...") # Debug log try: # Move model to CPU and delete reference if hasattr(model, 'cpu'): model.cpu() del model # Force garbage collection import gc gc.collect() # Clear GPU cache - both CUDA and MPS try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear MPS cache for Apple Silicon if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.empty_cache() except: pass print("Model cleanup completed") # Debug log except Exception as cleanup_error: print(f"Error during model cleanup: {cleanup_error}") # Debug log with gr.Blocks() as demo: gr.Markdown("# Chatterbox TTS Generator") # Store speakers in a global state speakers_state = gr.State(speakers) with gr.Tabs() as tabs: with gr.TabItem("Single Utterance"): with gr.Row(): with gr.Column(): speaker_dropdown = gr.Dropdown( choices=["Custom"] + list(speakers.keys()), value="Custom", label="Select Speaker" ) custom_upload = gr.Audio( label="Or upload custom speaker sample", type="filepath", visible=True ) text_input = gr.Textbox( label="Text to synthesize", placeholder="Enter text here...", lines=3 ) exaggeration_slider = gr.Slider( minimum=0.0, maximum=2.0, value=0.5, step=0.01, label="Exaggeration (emotion)", info="Controls expressiveness. 0.5 = neutral, higher = more expressive." ) cfg_weight_slider = gr.Slider( minimum=0.0, maximum=2.0, value=0.5, step=0.01, label="CFG Weight", info="Higher = more faithful to text, lower = more like reference voice." ) temperature_slider = gr.Slider( minimum=0.1, maximum=2.0, value=0.8, step=0.01, label="Temperature", info="Controls randomness. Higher = more variation." ) max_new_tokens_box = gr.Number( value=1000, label="Max New Tokens (advanced)", precision=0, info="Maximum audio tokens to generate. Increase for longer texts." ) generate_btn = gr.Button("Generate Speech") with gr.Column(): audio_output = gr.Audio(label="Generated Speech") download = gr.File(label="Download WAV") gr.Examples( examples=[ ["Hello world! This is a demo.", "Tara"], ["Welcome to the future of text-to-speech.", "Zac"] ], inputs=[text_input, speaker_dropdown] ) generate_btn.click( fn=generate_audio, inputs=[speaker_dropdown, custom_upload, text_input, exaggeration_slider, cfg_weight_slider, temperature_slider, max_new_tokens_box], outputs=[audio_output, download] ) with gr.TabItem("Dialog Generation"): with gr.Row(): with gr.Column(): with gr.Row(): with gr.Column(scale=2): dialog_text = gr.Textbox( label="Dialog Text", placeholder='''Enter dialog in format: Speaker1: "Hello, how are you?" Speaker2: "I'm doing well, thank you!" Speaker1: "What are your plans for today?" Speaker2: "I'm working on a new project."''', lines=10 ) with gr.Column(scale=1): with gr.Group(): gr.Markdown("### Speaker Configuration") with gr.Row(): new_speaker_name = gr.Textbox( label="New Speaker Name", placeholder="Enter speaker name" ) new_speaker_audio = gr.Audio( label="Speaker Sample", type="filepath" ) with gr.Row(): add_speaker_btn = gr.Button("Add Speaker") remove_speaker_btn = gr.Button("Remove Selected") speakers_dropdown = gr.Dropdown( label="Available Speakers", choices=list(speakers.keys()) if speakers else [], interactive=True, multiselect=True ) gr.Markdown("### Generation Settings") with gr.Row(): output_base = gr.Textbox( label="Output Base Name", value="dialog_output", placeholder="base_name (will generate 001-base_name.wav, etc.)" ) reinit_each_line = gr.Checkbox( label="Re-initialize model each line", value=False, info="Reduces memory usage but is slower" ) config_status = gr.Textbox( label="Status", interactive=False, visible=True ) dialog_generate_btn = gr.Button("Generate Dialog") with gr.Column(): dialog_output = gr.Textbox( label="Generation Log", interactive=False, lines=15 ) concatenated_audio = gr.Audio( label="Concatenated Audio", visible=False ) dialog_download = gr.File( label="Download All Files", visible=False ) # Event handlers add_speaker_btn.click( fn=add_speaker, inputs=[new_speaker_name, new_speaker_audio, speakers_state], outputs=[config_status, speakers_state] ) remove_speaker_btn.click( fn=remove_speaker, inputs=[speakers_dropdown, speakers_state], outputs=[config_status, speakers_state] ) # Update the speakers dropdown when speakers_state changes def update_dropdown(speakers_dict): return gr.Dropdown(choices=list(speakers_dict.keys()) if speakers_dict else []) speakers_state.change( fn=update_dropdown, inputs=[speakers_state], outputs=[speakers_dropdown] ) # Update the speakers dropdown when the tab is selected def on_tab_change(): # Return a dictionary that updates the dropdown choices return {"__type__": "update", "choices": list(speakers.keys())} tabs.select( fn=on_tab_change, inputs=[], outputs=[speakers_dropdown] ) def update_outputs(*args): result = process_dialog(*args) if len(result) == 3: summary, concat_file, zip_file = result if concat_file: # When we have a concatenated file, show both the audio player and download return [ summary, # dialog_output gr.Audio(value=concat_file, visible=True), # concatenated_audio gr.File(value=zip_file, visible=True) # dialog_download ] # When no concatenated file, just show the zip download return [ summary, gr.Audio(visible=False), gr.File(value=zip_file, visible=True) ] # Error case return [ result[0], # error message gr.Audio(visible=False), gr.File(visible=False) ] # Update the click handler with the correct number of outputs dialog_generate_btn.click( fn=update_outputs, inputs=[ dialog_text, speakers_state, # Pass the current speakers dict output_base, reinit_each_line ], outputs=[ dialog_output, # Text output concatenated_audio, # Audio component dialog_download # Zip file ] ) if __name__ == "__main__": demo.launch(share=True)