Added dialog generation with concatenation
This commit is contained in:
parent
63efb26910
commit
869914e8a0
|
@ -3,3 +3,5 @@
|
|||
output*.wav
|
||||
*.wav
|
||||
*.mp3
|
||||
dialog_output/
|
||||
*.zip
|
||||
|
|
414
gradio_app.py
414
gradio_app.py
|
@ -2,8 +2,14 @@ import gradio as gr
|
|||
import yaml
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
import torchaudio as ta
|
||||
import torch
|
||||
from typing import Dict, Tuple, Optional, List
|
||||
|
||||
# Load speaker options from YAML with error handling
|
||||
try:
|
||||
|
@ -20,6 +26,102 @@ try:
|
|||
except Exception as e:
|
||||
raise SystemExit(f"Failed to load speakers.yaml: {str(e)}")
|
||||
|
||||
def split_text_at_sentence_boundaries(text, max_length=300):
|
||||
"""Split text at sentence boundaries, ensuring each chunk is <= max_length."""
|
||||
sentence_pattern = r'[.!?](?:\s|$)'
|
||||
sentences = re.split(f'({sentence_pattern})', text)
|
||||
|
||||
actual_sentences = []
|
||||
current = ""
|
||||
for i in range(0, len(sentences), 2):
|
||||
if i+1 < len(sentences):
|
||||
current = sentences[i] + sentences[i+1]
|
||||
else:
|
||||
current = sentences[i]
|
||||
if current:
|
||||
actual_sentences.append(current)
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for sentence in actual_sentences:
|
||||
if len(current_chunk) + len(sentence) > max_length and current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence
|
||||
else:
|
||||
current_chunk += sentence
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
def parse_dialog_line(line):
|
||||
"""Parse a dialog line in the format: Name: "Text"""
|
||||
pattern = r'^([^:]+):\s*"([^"]+)"$'
|
||||
match = re.match(pattern, line.strip())
|
||||
if match:
|
||||
speaker = match.group(1).strip()
|
||||
text = match.group(2).strip()
|
||||
return (speaker, text)
|
||||
return None
|
||||
|
||||
def save_speakers_config(speakers_dict: Dict) -> str:
|
||||
"""Save speakers configuration to YAML file."""
|
||||
try:
|
||||
with open("speakers.yaml", 'w') as f:
|
||||
yaml.dump(speakers_dict, f)
|
||||
return "Speakers configuration saved successfully!"
|
||||
except Exception as e:
|
||||
return f"Error saving configuration: {str(e)}"
|
||||
|
||||
def add_speaker(speaker_name: str, audio_file: str, current_speakers: Dict) -> Tuple[str, Dict]:
|
||||
"""Add a new speaker with audio sample."""
|
||||
if not speaker_name or not audio_file:
|
||||
return "Please provide both speaker name and audio file", current_speakers
|
||||
|
||||
if speaker_name in current_speakers:
|
||||
return f"Speaker '{speaker_name}' already exists!", current_speakers
|
||||
|
||||
# Save the audio file
|
||||
speakers_dir = "speaker_samples"
|
||||
os.makedirs(speakers_dir, exist_ok=True)
|
||||
|
||||
# Generate a unique filename
|
||||
ext = os.path.splitext(audio_file)[1] or '.wav'
|
||||
new_filename = f"{speaker_name.lower().replace(' ', '_')}{ext}"
|
||||
new_filepath = os.path.join(speakers_dir, new_filename)
|
||||
|
||||
# Copy the uploaded file
|
||||
shutil.copy2(audio_file, new_filepath)
|
||||
|
||||
# Update speakers dictionary
|
||||
updated_speakers = dict(current_speakers)
|
||||
updated_speakers[speaker_name] = new_filepath
|
||||
|
||||
# Save to YAML
|
||||
save_speakers_config(updated_speakers)
|
||||
|
||||
return f"Speaker '{speaker_name}' added successfully!", updated_speakers
|
||||
|
||||
def remove_speaker(speaker_name: str, current_speakers: Dict) -> Tuple[str, Dict]:
|
||||
"""Remove a speaker from the configuration."""
|
||||
if not speaker_name or speaker_name not in current_speakers:
|
||||
return "Speaker not found!", current_speakers
|
||||
|
||||
# Don't actually delete the audio file, just remove from config
|
||||
updated_speakers = dict(current_speakers)
|
||||
del updated_speakers[speaker_name]
|
||||
|
||||
# Save to YAML
|
||||
save_speakers_config(updated_speakers)
|
||||
|
||||
return f"Speaker '{speaker_name}' removed!", updated_speakers
|
||||
|
||||
def update_speakers_dropdown():
|
||||
"""Update the speakers dropdown with current speakers."""
|
||||
return gr.Dropdown.update(choices=list(speakers.keys()))
|
||||
|
||||
def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens):
|
||||
# Get sample path from selection or upload
|
||||
sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
|
||||
|
@ -47,10 +149,182 @@ def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight
|
|||
|
||||
return output_path, output_path
|
||||
|
||||
def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()):
|
||||
"""Process dialog text and generate audio files."""
|
||||
try:
|
||||
print("Starting dialog processing...") # Debug log
|
||||
print(f"Speaker samples: {speaker_samples}") # Debug log
|
||||
|
||||
if not dialog_text or not dialog_text.strip():
|
||||
return "Error: No dialog text provided", None
|
||||
|
||||
# Parse dialog lines
|
||||
dialog_lines = [line.strip() for line in dialog_text.split('\n') if line.strip()]
|
||||
|
||||
if not dialog_lines:
|
||||
return "Error: No valid dialog lines found", None
|
||||
|
||||
print(f"Processing {len(dialog_lines)} dialog lines") # Debug log
|
||||
|
||||
# Initialize model only once if not reinitializing per line
|
||||
model = None
|
||||
if not reinit_each_line:
|
||||
progress(0.1, desc="Loading TTS model...")
|
||||
try:
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
print("TTS model loaded successfully") # Debug log
|
||||
except Exception as e:
|
||||
return f"Error loading TTS model: {str(e)}", None
|
||||
|
||||
# Create output directory
|
||||
output_dir = "dialog_output"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Process each dialog line
|
||||
file_counter = 1
|
||||
output_files = []
|
||||
summary = []
|
||||
|
||||
for i, line in enumerate(dialog_lines):
|
||||
progress(i / len(dialog_lines), desc=f"Processing line {i+1}/{len(dialog_lines)}")
|
||||
print(f"Processing line {i+1}: {line}") # Debug log
|
||||
|
||||
try:
|
||||
parsed = parse_dialog_line(line)
|
||||
if not parsed:
|
||||
print(f"Skipping line (invalid format): {line}") # Debug log
|
||||
continue
|
||||
|
||||
speaker, text = parsed
|
||||
print(f"Found speaker: {speaker}, text: {text[:50]}...") # Debug log
|
||||
|
||||
if speaker not in speaker_samples:
|
||||
msg = f"Skipping unknown speaker: {speaker}"
|
||||
print(msg) # Debug log
|
||||
summary.append(msg)
|
||||
continue
|
||||
|
||||
sample_path = speaker_samples[speaker]
|
||||
if not os.path.exists(sample_path):
|
||||
msg = f"Audio sample not found for speaker '{speaker}': {sample_path}"
|
||||
print(msg) # Debug log
|
||||
summary.append(msg)
|
||||
continue
|
||||
|
||||
if reinit_each_line or model is None:
|
||||
print("Initializing new TTS model instance") # Debug log
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
if len(text) > 300:
|
||||
chunks = split_text_at_sentence_boundaries(text)
|
||||
print(f"Splitting long text into {len(chunks)} chunks") # Debug log
|
||||
for chunk in chunks:
|
||||
output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
|
||||
print(f"Generating audio for chunk: {chunk[:50]}...") # Debug log
|
||||
try:
|
||||
wav = model.generate(chunk, audio_prompt_path=sample_path)
|
||||
ta.save(output_file, wav, model.sr)
|
||||
output_files.append(output_file)
|
||||
summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...")
|
||||
file_counter += 1
|
||||
print(f"Generated audio: {output_file}") # Debug log
|
||||
except Exception as e:
|
||||
error_msg = f"Error generating audio for chunk: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
summary.append(error_msg)
|
||||
continue
|
||||
else:
|
||||
output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
|
||||
print(f"Generating audio: {text[:50]}...") # Debug log
|
||||
try:
|
||||
wav = model.generate(text, audio_prompt_path=sample_path)
|
||||
ta.save(output_file, wav, model.sr)
|
||||
output_files.append(output_file)
|
||||
summary.append(f"{output_file}: {speaker} - {text[:50]}...")
|
||||
file_counter += 1
|
||||
print(f"Generated audio: {output_file}") # Debug log
|
||||
except Exception as e:
|
||||
error_msg = f"Error generating audio: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
summary.append(error_msg)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing line '{line}': {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
summary.append(error_msg)
|
||||
continue
|
||||
|
||||
if not output_files:
|
||||
return "Error: No audio files were generated. Check speaker names and audio samples.", None
|
||||
|
||||
# Concatenate all audio files with 1-second gaps
|
||||
concatenated_file = None
|
||||
try:
|
||||
if len(output_files) > 1:
|
||||
print("Concatenating audio files...") # Debug log
|
||||
# Load all audio files
|
||||
waveforms = []
|
||||
sample_rates = set()
|
||||
|
||||
for file in output_files:
|
||||
waveform, sample_rate = ta.load(file)
|
||||
waveforms.append(waveform)
|
||||
sample_rates.add(sample_rate)
|
||||
|
||||
if len(sample_rates) != 1:
|
||||
raise ValueError(f"Sample rate mismatch: {sample_rates}")
|
||||
|
||||
sample_rate = sample_rates.pop()
|
||||
gap_samples = int(1.0 * sample_rate) # 1 second gap
|
||||
gap = torch.zeros(1, gap_samples) # Mono channel
|
||||
|
||||
# Concatenate waveforms with gaps
|
||||
concatenated = waveforms[0]
|
||||
for wav in waveforms[1:]:
|
||||
concatenated = torch.cat([concatenated, gap, wav], dim=1)
|
||||
|
||||
# Save concatenated file
|
||||
concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav")
|
||||
ta.save(concatenated_path, concatenated, sample_rate)
|
||||
output_files.append(concatenated_path)
|
||||
summary.append(f"\nConcatenated file: {concatenated_path}")
|
||||
concatenated_file = concatenated_path
|
||||
print(f"Created concatenated file: {concatenated_path}")
|
||||
|
||||
# Create a zip file of all outputs
|
||||
import zipfile
|
||||
zip_path = os.path.join(output_dir, f"{output_base}.zip")
|
||||
print(f"Creating zip file: {zip_path}") # Debug log
|
||||
with zipfile.ZipFile(zip_path, 'w') as zipf:
|
||||
for file in output_files:
|
||||
zipf.write(file, os.path.basename(file))
|
||||
print(f"Zip file created successfully with {len(output_files)} files") # Debug log
|
||||
|
||||
# Return both the zip and the concatenated file if it exists
|
||||
if concatenated_file:
|
||||
return "\n".join(summary), concatenated_file, zip_path
|
||||
return "\n".join(summary), None, zip_path
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating zip file: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
return "\n".join(summary + [error_msg]), None
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
import traceback
|
||||
traceback.print_exc() # Print full traceback
|
||||
return error_msg, None
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# Chatterbox TTS Generator")
|
||||
|
||||
# Store speakers in a global state
|
||||
speakers_state = gr.State(speakers)
|
||||
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.TabItem("Single Utterance"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
speaker_dropdown = gr.Dropdown(
|
||||
|
@ -105,9 +379,147 @@ with gr.Blocks() as demo:
|
|||
|
||||
generate_btn.click(
|
||||
fn=generate_audio,
|
||||
inputs=[speaker_dropdown, custom_upload, text_input, exaggeration_slider, cfg_weight_slider, temperature_slider, max_new_tokens_box],
|
||||
inputs=[speaker_dropdown, custom_upload, text_input, exaggeration_slider,
|
||||
cfg_weight_slider, temperature_slider, max_new_tokens_box],
|
||||
outputs=[audio_output, download]
|
||||
)
|
||||
|
||||
with gr.TabItem("Dialog Generation"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
dialog_text = gr.Textbox(
|
||||
label="Dialog Text",
|
||||
placeholder='''Enter dialog in format:
|
||||
Speaker1: "Hello, how are you?"
|
||||
Speaker2: "I'm doing well, thank you!"
|
||||
Speaker1: "What are your plans for today?"
|
||||
Speaker2: "I'm working on a new project."''',
|
||||
lines=10
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
with gr.Group():
|
||||
gr.Markdown("### Speaker Configuration")
|
||||
with gr.Row():
|
||||
new_speaker_name = gr.Textbox(
|
||||
label="New Speaker Name",
|
||||
placeholder="Enter speaker name"
|
||||
)
|
||||
new_speaker_audio = gr.Audio(
|
||||
label="Speaker Sample",
|
||||
type="filepath"
|
||||
)
|
||||
with gr.Row():
|
||||
add_speaker_btn = gr.Button("Add Speaker")
|
||||
remove_speaker_btn = gr.Button("Remove Selected")
|
||||
|
||||
speakers_dropdown = gr.Dropdown(
|
||||
label="Available Speakers",
|
||||
choices=list(speakers.keys()) if speakers else [],
|
||||
interactive=True,
|
||||
multiselect=True
|
||||
)
|
||||
|
||||
gr.Markdown("### Generation Settings")
|
||||
with gr.Row():
|
||||
output_base = gr.Textbox(
|
||||
label="Output Base Name",
|
||||
value="dialog_output",
|
||||
placeholder="base_name (will generate 001-base_name.wav, etc.)"
|
||||
)
|
||||
reinit_each_line = gr.Checkbox(
|
||||
label="Re-initialize model each line",
|
||||
value=False,
|
||||
info="Reduces memory usage but is slower"
|
||||
)
|
||||
|
||||
config_status = gr.Textbox(
|
||||
label="Status",
|
||||
interactive=False,
|
||||
visible=True
|
||||
)
|
||||
|
||||
dialog_generate_btn = gr.Button("Generate Dialog")
|
||||
|
||||
with gr.Column():
|
||||
dialog_output = gr.Textbox(
|
||||
label="Generation Log",
|
||||
interactive=False,
|
||||
lines=15
|
||||
)
|
||||
concatenated_audio = gr.Audio(
|
||||
label="Concatenated Audio",
|
||||
visible=False
|
||||
)
|
||||
dialog_download = gr.File(
|
||||
label="Download All Files",
|
||||
visible=False
|
||||
)
|
||||
|
||||
# Event handlers
|
||||
add_speaker_btn.click(
|
||||
fn=add_speaker,
|
||||
inputs=[new_speaker_name, new_speaker_audio, speakers_state],
|
||||
outputs=[config_status, speakers_dropdown]
|
||||
)
|
||||
|
||||
remove_speaker_btn.click(
|
||||
fn=remove_speaker,
|
||||
inputs=[speakers_dropdown, speakers_state],
|
||||
outputs=[config_status, speakers_dropdown]
|
||||
)
|
||||
|
||||
# Update the speakers dropdown when the tab is selected
|
||||
def on_tab_change():
|
||||
# Return a dictionary that updates the dropdown choices
|
||||
return {"__type__": "update", "choices": list(speakers.keys())}
|
||||
|
||||
tabs.select(
|
||||
fn=on_tab_change,
|
||||
inputs=[],
|
||||
outputs=[speakers_dropdown]
|
||||
)
|
||||
|
||||
def update_outputs(*args):
|
||||
result = process_dialog(*args)
|
||||
if len(result) == 3:
|
||||
summary, concat_file, zip_file = result
|
||||
if concat_file:
|
||||
# When we have a concatenated file, show both the audio player and download
|
||||
return [
|
||||
summary, # dialog_output
|
||||
gr.Audio(value=concat_file, visible=True), # concatenated_audio
|
||||
gr.File(value=zip_file, visible=True) # dialog_download
|
||||
]
|
||||
# When no concatenated file, just show the zip download
|
||||
return [
|
||||
summary,
|
||||
gr.Audio(visible=False),
|
||||
gr.File(value=zip_file, visible=True)
|
||||
]
|
||||
# Error case
|
||||
return [
|
||||
result[0], # error message
|
||||
gr.Audio(visible=False),
|
||||
gr.File(visible=False)
|
||||
]
|
||||
|
||||
# Update the click handler with the correct number of outputs
|
||||
dialog_generate_btn.click(
|
||||
fn=update_outputs,
|
||||
inputs=[
|
||||
dialog_text,
|
||||
speakers_state, # Pass the current speakers dict
|
||||
output_base,
|
||||
reinit_each_line
|
||||
],
|
||||
outputs=[
|
||||
dialog_output, # Text output
|
||||
concatenated_audio, # Audio component
|
||||
dialog_download # Zip file
|
||||
]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(share=True)
|
||||
|
|
Loading…
Reference in New Issue