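"""Gradio front-end for Chatterbox TTS: single-utterance synthesis plus
multi-speaker dialog generation driven by a speakers.yaml voice map."""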

import gradio as gr
import yaml
import os
import time
import re
import shutil

from chatterbox.tts import ChatterboxTTS
import torchaudio as ta
import torch
from typing import Dict, Tuple

# Load speaker options from YAML with error handling
try:
    yaml_path = os.path.abspath("speakers.yaml")
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(f"speakers.yaml not found at {yaml_path}")

    with open(yaml_path) as f:
        speakers = yaml.safe_load(f)

    if not speakers or not isinstance(speakers, dict):
        raise ValueError("speakers.yaml must contain a valid dictionary mapping")

except Exception as e:
    raise SystemExit(f"Failed to load speakers.yaml: {str(e)}")
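
# Expected speakers.yaml layout (illustrative names and paths):
#   Alice: speaker_samples/alice.wav
#   Bob: speaker_samples/bob.wav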


def split_text_at_sentence_boundaries(text, max_length=300):
    """Split text at sentence boundaries, ensuring each chunk is <= max_length."""
    # Sentence enders followed by whitespace or end-of-string; wrapping the
    # pattern in a capturing group makes re.split keep the delimiters
    sentence_pattern = r'[.!?](?:\s|$)'
    sentences = re.split(f'({sentence_pattern})', text)

    # Re-attach each captured delimiter to the sentence that precedes it
    actual_sentences = []
    for i in range(0, len(sentences), 2):
        if i + 1 < len(sentences):
            current = sentences[i] + sentences[i + 1]
        else:
            current = sentences[i]
        if current:
            actual_sentences.append(current)

    # Greedily pack sentences into chunks of at most max_length characters
    chunks = []
    current_chunk = ""

    for sentence in actual_sentences:
        if len(current_chunk) + len(sentence) > max_length and current_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
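
# Illustrative behavior (assumed example, not executed at import time):
#   split_text_at_sentence_boundaries("First. Second. Third.", max_length=10)
#   -> ["First.", "Second.", "Third."]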


def parse_dialog_line(line):
    '''Parse a dialog line in the format: Name: "Text"'''
    # Speaker name is everything before the first colon; the double-quoted
    # text must run to the end of the line
    pattern = r'^([^:]+):\s*"([^"]+)"$'
    match = re.match(pattern, line.strip())
    if match:
        speaker = match.group(1).strip()
        text = match.group(2).strip()
        return (speaker, text)
    return None
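
# Illustrative behavior:
#   parse_dialog_line('Alice: "Hello there"')  -> ("Alice", "Hello there")
#   parse_dialog_line('not a dialog line')     -> None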


def save_speakers_config(speakers_dict: Dict) -> str:
    """Save speakers configuration to YAML file."""
    try:
        with open("speakers.yaml", 'w') as f:
            yaml.dump(speakers_dict, f)
        return "Speakers configuration saved successfully!"
    except Exception as e:
        return f"Error saving configuration: {str(e)}"


def add_speaker(speaker_name: str, audio_file: str, current_speakers: Dict) -> Tuple[str, Dict, Dict]:
    """Add a new speaker with an audio sample."""
    if not speaker_name or not audio_file:
        return "Please provide both speaker name and audio file", current_speakers, gr.update()

    if speaker_name in current_speakers:
        return f"Speaker '{speaker_name}' already exists!", current_speakers, gr.update()

    # Save the audio file
    speakers_dir = "speaker_samples"
    os.makedirs(speakers_dir, exist_ok=True)

    # Generate a unique filename
    ext = os.path.splitext(audio_file)[1] or '.wav'
    new_filename = f"{speaker_name.lower().replace(' ', '_')}{ext}"
    new_filepath = os.path.join(speakers_dir, new_filename)

    # Copy the uploaded file
    shutil.copy2(audio_file, new_filepath)

    # Update speakers dictionary
    updated_speakers = dict(current_speakers)
    updated_speakers[speaker_name] = new_filepath

    # Save to YAML
    save_speakers_config(updated_speakers)

    # The third value refreshes the dropdown choices; returning the raw dict
    # (as the old code did) is not a valid dropdown update
    return (f"Speaker '{speaker_name}' added successfully!", updated_speakers,
            gr.update(choices=list(updated_speakers.keys())))


def remove_speaker(speaker_names, current_speakers: Dict) -> Tuple[str, Dict, Dict]:
    """Remove one or more speakers from the configuration."""
    # The dropdown is multiselect, so Gradio passes a list of names here
    if isinstance(speaker_names, str):
        speaker_names = [speaker_names]
    if not speaker_names:
        return "Speaker not found!", current_speakers, gr.update()

    # Don't actually delete the audio files, just remove the config entries
    updated_speakers = dict(current_speakers)
    removed = []
    for name in speaker_names:
        if name in updated_speakers:
            del updated_speakers[name]
            removed.append(name)

    if not removed:
        return "Speaker not found!", current_speakers, gr.update()

    # Save to YAML
    save_speakers_config(updated_speakers)

    return (f"Removed: {', '.join(removed)}", updated_speakers,
            gr.update(choices=list(updated_speakers.keys()), value=[]))


def update_speakers_dropdown():
    """Update the speakers dropdown with current speakers."""
    # gr.Dropdown.update() was removed in Gradio 4; gr.update() works in 3.x and 4.x
    return gr.update(choices=list(speakers.keys()))


def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens):
    # Get sample path from selection or upload
    sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample

    if not sample_path or not os.path.exists(sample_path):
        raise gr.Error("Sample file not found!")

    # Load model on every call; "mps" targets Apple Silicon, change to "cuda"
    # or "cpu" for other hardware
    tts = ChatterboxTTS.from_pretrained(device="mps")

    # Generate audio with advanced controls
    gen_kwargs = dict(
        text=text,
        audio_prompt_path=sample_path,
        exaggeration=exaggeration,
        cfg_weight=cfg_weight,
        temperature=temperature
    )
    # max_new_tokens is not supported by the current TTS library, so we ignore it here
    wav = tts.generate(**gen_kwargs)

    # Save with timestamp
    output_path = f"output_{int(time.time())}.wav"
    ta.save(output_path, wav, tts.sr)

    # The same path feeds both the audio player and the download component
    return output_path, output_path
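
# Reloading the model per request is slow; a minimal caching sketch (an
# assumption about intended usage, not wired into the UI) might look like:
#   _TTS_CACHE = {}
#   def get_tts(device="mps"):
#       if device not in _TTS_CACHE:
#           _TTS_CACHE[device] = ChatterboxTTS.from_pretrained(device=device)
#       return _TTS_CACHE[device]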


def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()):
    """Process dialog text and generate audio files.

    Always returns a (summary, concatenated_file_or_None, zip_path_or_None)
    triple so the UI callback can unpack it uniformly.
    """
    try:
        print("Starting dialog processing...")  # Debug log
        print(f"Speaker samples: {speaker_samples}")  # Debug log

        if not dialog_text or not dialog_text.strip():
            return "Error: No dialog text provided", None, None

        # Parse dialog lines
        dialog_lines = [line.strip() for line in dialog_text.split('\n') if line.strip()]

        if not dialog_lines:
            return "Error: No valid dialog lines found", None, None

        print(f"Processing {len(dialog_lines)} dialog lines")  # Debug log

        # Initialize model only once if not reinitializing per line
        model = None
        if not reinit_each_line:
            progress(0.1, desc="Loading TTS model...")
            try:
                model = ChatterboxTTS.from_pretrained(device="mps")
                print("TTS model loaded successfully")  # Debug log
            except Exception as e:
                return f"Error loading TTS model: {str(e)}", None, None

        # Create output directory
        output_dir = "dialog_output"
        os.makedirs(output_dir, exist_ok=True)

        # Process each dialog line
        file_counter = 1
        output_files = []
        summary = []

        for i, line in enumerate(dialog_lines):
            progress(i / len(dialog_lines), desc=f"Processing line {i+1}/{len(dialog_lines)}")
            print(f"Processing line {i+1}: {line}")  # Debug log

            try:
                parsed = parse_dialog_line(line)
                if not parsed:
                    print(f"Skipping line (invalid format): {line}")  # Debug log
                    continue

                speaker, text = parsed
                print(f"Found speaker: {speaker}, text: {text[:50]}...")  # Debug log

                if speaker not in speaker_samples:
                    msg = f"Skipping unknown speaker: {speaker}"
                    print(msg)  # Debug log
                    summary.append(msg)
                    continue

                sample_path = speaker_samples[speaker]
                if not os.path.exists(sample_path):
                    msg = f"Audio sample not found for speaker '{speaker}': {sample_path}"
                    print(msg)  # Debug log
                    summary.append(msg)
                    continue

                if reinit_each_line or model is None:
                    print("Initializing new TTS model instance")  # Debug log
                    model = ChatterboxTTS.from_pretrained(device="mps")

                if len(text) > 300:
                    # Long lines are synthesized chunk by chunk at sentence boundaries
                    chunks = split_text_at_sentence_boundaries(text)
                    print(f"Splitting long text into {len(chunks)} chunks")  # Debug log
                    for chunk in chunks:
                        output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                        print(f"Generating audio for chunk: {chunk[:50]}...")  # Debug log
                        try:
                            wav = model.generate(chunk, audio_prompt_path=sample_path)
                            ta.save(output_file, wav, model.sr)
                            output_files.append(output_file)
                            summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...")
                            file_counter += 1
                            print(f"Generated audio: {output_file}")  # Debug log
                        except Exception as e:
                            error_msg = f"Error generating audio for chunk: {str(e)}"
                            print(error_msg)  # Debug log
                            summary.append(error_msg)
                            continue
                else:
                    output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
                    print(f"Generating audio: {text[:50]}...")  # Debug log
                    try:
                        wav = model.generate(text, audio_prompt_path=sample_path)
                        ta.save(output_file, wav, model.sr)
                        output_files.append(output_file)
                        summary.append(f"{output_file}: {speaker} - {text[:50]}...")
                        file_counter += 1
                        print(f"Generated audio: {output_file}")  # Debug log
                    except Exception as e:
                        error_msg = f"Error generating audio: {str(e)}"
                        print(error_msg)  # Debug log
                        summary.append(error_msg)
                        continue

            except Exception as e:
                error_msg = f"Error processing line '{line}': {str(e)}"
                print(error_msg)  # Debug log
                summary.append(error_msg)
                continue

        if not output_files:
            return "Error: No audio files were generated. Check speaker names and audio samples.", None, None

        # Concatenate all audio files with 1-second gaps
        concatenated_file = None
        try:
            if len(output_files) > 1:
                print("Concatenating audio files...")  # Debug log
                # Load all audio files
                waveforms = []
                sample_rates = set()

                for file in output_files:
                    waveform, sample_rate = ta.load(file)
                    waveforms.append(waveform)
                    sample_rates.add(sample_rate)

                if len(sample_rates) != 1:
                    raise ValueError(f"Sample rate mismatch: {sample_rates}")

                sample_rate = sample_rates.pop()
                gap_samples = int(1.0 * sample_rate)  # 1 second gap
                gap = torch.zeros(1, gap_samples)  # Mono channel

                # Concatenate waveforms with gaps
                concatenated = waveforms[0]
                for wav in waveforms[1:]:
                    concatenated = torch.cat([concatenated, gap, wav], dim=1)

                # Save concatenated file
                concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav")
                ta.save(concatenated_path, concatenated, sample_rate)
                output_files.append(concatenated_path)
                summary.append(f"\nConcatenated file: {concatenated_path}")
                concatenated_file = concatenated_path
                print(f"Created concatenated file: {concatenated_path}")

            # Create a zip file of all outputs
            import zipfile
            zip_path = os.path.join(output_dir, f"{output_base}.zip")
            print(f"Creating zip file: {zip_path}")  # Debug log
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                for file in output_files:
                    zipf.write(file, os.path.basename(file))
            print(f"Zip file created successfully with {len(output_files)} files")  # Debug log

            # concatenated_file is None when only one clip was generated
            return "\n".join(summary), concatenated_file, zip_path
        except Exception as e:
            error_msg = f"Error creating zip file: {str(e)}"
            print(error_msg)  # Debug log
            return "\n".join(summary + [error_msg]), None, None

    except Exception as e:
        error_msg = f"Unexpected error: {str(e)}"
        print(error_msg)  # Debug log
        import traceback
        traceback.print_exc()  # Print full traceback
        return error_msg, None, None
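
# Illustrative dialog input for process_dialog (one utterance per line;
# anything that doesn't match the Name: "Text" pattern is skipped):
#   Alice: "Hello, how are you?"
#   Bob: "I'm doing well, thank you!"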


with gr.Blocks() as demo:
    gr.Markdown("# Chatterbox TTS Generator")

    # Store the speakers mapping in Gradio state so UI edits persist per session
    speakers_state = gr.State(speakers)

    with gr.Tabs() as tabs:
        with gr.TabItem("Single Utterance"):
            with gr.Row():
                with gr.Column():
                    speaker_dropdown = gr.Dropdown(
                        choices=["Custom"] + list(speakers.keys()),
                        value="Custom",
                        label="Select Speaker"
                    )
                    custom_upload = gr.Audio(
                        label="Or upload custom speaker sample",
                        type="filepath",
                        visible=True
                    )
                    text_input = gr.Textbox(
                        label="Text to synthesize",
                        placeholder="Enter text here...",
                        lines=3
                    )
                    exaggeration_slider = gr.Slider(
                        minimum=0.0, maximum=2.0, value=0.5, step=0.01,
                        label="Exaggeration (emotion)",
                        info="Controls expressiveness. 0.5 = neutral, higher = more expressive."
                    )
                    cfg_weight_slider = gr.Slider(
                        minimum=0.0, maximum=2.0, value=0.5, step=0.01,
                        label="CFG Weight",
                        info="Higher = more faithful to text, lower = more like reference voice."
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.8, step=0.01,
                        label="Temperature",
                        info="Controls randomness. Higher = more variation."
                    )
                    max_new_tokens_box = gr.Number(
                        value=1000,
                        label="Max New Tokens (advanced)",
                        precision=0,
                        info="Maximum audio tokens to generate. Increase for longer texts."
                    )
                    generate_btn = gr.Button("Generate Speech")

                with gr.Column():
                    audio_output = gr.Audio(label="Generated Speech")
                    download = gr.File(label="Download WAV")

            gr.Examples(
                examples=[
                    ["Hello world! This is a demo.", "Tara"],
                    ["Welcome to the future of text-to-speech.", "Zac"]
                ],
                inputs=[text_input, speaker_dropdown]
            )

            generate_btn.click(
                fn=generate_audio,
                inputs=[speaker_dropdown, custom_upload, text_input, exaggeration_slider,
                        cfg_weight_slider, temperature_slider, max_new_tokens_box],
                outputs=[audio_output, download]
            )

        with gr.TabItem("Dialog Generation"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column(scale=2):
                            dialog_text = gr.Textbox(
                                label="Dialog Text",
                                placeholder='''Enter dialog in format:
Speaker1: "Hello, how are you?"
Speaker2: "I'm doing well, thank you!"
Speaker1: "What are your plans for today?"
Speaker2: "I'm working on a new project."''',
                                lines=10
                            )
                        with gr.Column(scale=1):
                            with gr.Group():
                                gr.Markdown("### Speaker Configuration")
                                with gr.Row():
                                    new_speaker_name = gr.Textbox(
                                        label="New Speaker Name",
                                        placeholder="Enter speaker name"
                                    )
                                    new_speaker_audio = gr.Audio(
                                        label="Speaker Sample",
                                        type="filepath"
                                    )
                                with gr.Row():
                                    add_speaker_btn = gr.Button("Add Speaker")
                                    remove_speaker_btn = gr.Button("Remove Selected")

                                speakers_dropdown = gr.Dropdown(
                                    label="Available Speakers",
                                    choices=list(speakers.keys()) if speakers else [],
                                    interactive=True,
                                    multiselect=True
                                )

                    gr.Markdown("### Generation Settings")
                    with gr.Row():
                        output_base = gr.Textbox(
                            label="Output Base Name",
                            value="dialog_output",
                            placeholder="base_name (will generate 001-base_name.wav, etc.)"
                        )
                        reinit_each_line = gr.Checkbox(
                            label="Re-initialize model each line",
                            value=False,
                            info="Reduces memory usage but is slower"
                        )

                    config_status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        visible=True
                    )

                    dialog_generate_btn = gr.Button("Generate Dialog")

                with gr.Column():
                    dialog_output = gr.Textbox(
                        label="Generation Log",
                        interactive=False,
                        lines=15
                    )
                    concatenated_audio = gr.Audio(
                        label="Concatenated Audio",
                        visible=False
                    )
                    dialog_download = gr.File(
                        label="Download All Files",
                        visible=False
                    )

    # Event handlers: add/remove return (status, state, dropdown update), so
    # both the state dict and the dropdown choices stay in sync
    add_speaker_btn.click(
        fn=add_speaker,
        inputs=[new_speaker_name, new_speaker_audio, speakers_state],
        outputs=[config_status, speakers_state, speakers_dropdown]
    )

    remove_speaker_btn.click(
        fn=remove_speaker,
        inputs=[speakers_dropdown, speakers_state],
        outputs=[config_status, speakers_state, speakers_dropdown]
    )

    # Refresh the speakers dropdown from state when the tab changes; reading
    # the module-level `speakers` dict here would miss speakers added in the UI
    def on_tab_change(current_speakers):
        return gr.update(choices=list(current_speakers.keys()))

    tabs.select(
        fn=on_tab_change,
        inputs=[speakers_state],
        outputs=[speakers_dropdown]
    )

    # Normalize process_dialog's (summary, concatenated, zip) triple into
    # component updates, hiding the audio and download widgets when absent
    def update_outputs(*args):
        summary, concat_file, zip_file = process_dialog(*args)
        return [
            summary,                                                 # dialog_output
            gr.Audio(value=concat_file, visible=bool(concat_file)),  # concatenated_audio
            gr.File(value=zip_file, visible=bool(zip_file)),         # dialog_download
        ]

    dialog_generate_btn.click(
        fn=update_outputs,
        inputs=[
            dialog_text,
            speakers_state,  # Pass the current speakers dict
            output_base,
            reinit_each_line
        ],
        outputs=[
            dialog_output,       # Text output
            concatenated_audio,  # Audio component
            dialog_download      # Zip file
        ]
    )
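
# share=True asks Gradio to open a temporary public tunnel (a *.gradio.live
# URL in recent versions) in addition to the local server; drop it to keep
# the app local-only.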

if __name__ == "__main__":
    demo.launch(share=True)