Compare commits
No commits in common. "648aedd99d21452a5b6d08dbd89f2467df7984e1" and "769daab4c7cf4fcff391051e2bece6d2742490ee" have entirely different histories.
648aedd99d
...
769daab4c7
|
@ -0,0 +1,7 @@
|
|||
.venv
|
||||
.gradio
|
||||
output*.wav
|
||||
*.wav
|
||||
*.mp3
|
||||
dialog_output/
|
||||
*.zip
|
|
@ -0,0 +1,89 @@
|
|||
# Chatterbox Dialog Generator
|
||||
|
||||
This tool generates audio files for dialog from a markdown file, using the Chatterbox TTS system. It maps speaker names to audio samples using a YAML configuration file.
|
||||
|
||||
## Features
|
||||
|
||||
- Maps speaker names to audio samples via a YAML config file
|
||||
- Processes markdown dialog files with lines in the format: `Name: "Text"`
|
||||
- Generates sequentially numbered audio files (e.g., `001-output.wav`, `002-output.wav`)
|
||||
- Automatically splits long dialog lines (>300 characters) at sentence boundaries
|
||||
- Provides a summary of generated files
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.6+
|
||||
- PyYAML
|
||||
- torchaudio
|
||||
- Chatterbox TTS library
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
python cbx-dialog-generate.py --config speakers.yaml --dialog sample-dialog.md --output-base output
|
||||
```
|
||||
|
||||
### Arguments
|
||||
|
||||
- `--config`: Path to the YAML config file mapping speaker names to audio samples
|
||||
- `--dialog`: Path to the markdown dialog file
|
||||
- `--output-base`: Base name for output files (e.g., "output" for "001-output.wav")
|
||||
- `--reinit-each-line`: Re-initialize the model after each line to reduce memory usage (useful for long dialogs)
|
||||
|
||||
## Config File Format (YAML)
|
||||
|
||||
The config file maps speaker names (as they appear in the dialog) to audio sample files:
|
||||
|
||||
```yaml
|
||||
Denise: denise.wav
|
||||
Mark: mark.wav
|
||||
Mary: mary.wav
|
||||
```
|
||||
|
||||
## Dialog File Format (Markdown)
|
||||
|
||||
The dialog file should contain lines in the format:
|
||||
|
||||
```
|
||||
Name: "Text"
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
Denise: "What do you think is wrong with me?"
|
||||
Mark: "I think you're being overly emotional."
|
||||
Mary: "Jesus, Mark, can you be any more of an asshole?"
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
The script generates sequentially numbered WAV files:
|
||||
|
||||
- `001-output.wav`
|
||||
- `002-output.wav`
|
||||
- etc.
|
||||
|
||||
If a dialog line exceeds 300 characters, it will be split at sentence boundaries into multiple files, each maintaining the sequential numbering.
|
||||
|
||||
## Example
|
||||
|
||||
Given the sample dialog and config files, running:
|
||||
|
||||
```bash
|
||||
python cbx-dialog-generate.py --config speakers.yaml --dialog sample-dialog.md --output-base output
|
||||
```
|
||||
|
||||
For long dialogs where memory usage is a concern, you can use:
|
||||
|
||||
```bash
|
||||
python cbx-dialog-generate.py --config speakers.yaml --dialog sample-dialog.md --output-base output --reinit-each-line
|
||||
```
|
||||
|
||||
Either command would generate:
|
||||
- `001-output.wav` - Denise's first line
|
||||
- `002-output.wav` - Mark's first line
|
||||
- `003-output.wav` - Mary's line
|
||||
- `004-output.wav` - First part of Denise's long line
|
||||
- `005-output.wav` - Second part of Denise's long line
|
||||
- `006-output.wav` - Mark's second line
|
73
README.md
73
README.md
|
@ -1,2 +1,73 @@
|
|||
# chatterbox-ui
|
||||
# Chatterbox TTS Gradio App
|
||||
|
||||
This Gradio application provides a user interface for text-to-speech generation using the Chatterbox TTS model. It supports both single utterance generation and multi-speaker dialog generation with configurable silence gaps.
|
||||
|
||||
## Features
|
||||
|
||||
- **Single Utterance Generation**: Generate speech from text using a selected speaker
|
||||
- **Dialog Generation**: Create multi-speaker conversations with configurable silence gaps
|
||||
- **Speaker Management**: Add/remove speakers with custom audio samples
|
||||
- **Memory Optimization**: Automatic model cleanup after generation
|
||||
- **Output Organization**: Files saved in `single_output/` and `dialog_output/` directories
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/your-username/chatterbox-test.git
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. Prepare speaker samples:
|
||||
- Create a `speaker_samples/` directory
|
||||
- Add audio samples (WAV format) for each speaker
|
||||
- Update `speakers.yaml` with speaker names and file paths
|
||||
|
||||
4. Run the app:
|
||||
```bash
|
||||
python gradio_app.py
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Single Utterance Tab
|
||||
- Select a speaker from the dropdown
|
||||
- Enter text to synthesize
|
||||
- Adjust generation parameters as needed
|
||||
- Click "Generate Speech"
|
||||
|
||||
### Dialog Generation Tab
|
||||
1. Add speakers using the speaker configuration section
|
||||
2. Enter dialog in the format:
|
||||
```
|
||||
Speaker1: "Hello, how are you?"
|
||||
Speaker2: "I'm doing well!"
|
||||
Silence: 0.5
|
||||
Speaker1: "What are your plans for today?"
|
||||
```
|
||||
3. Set output base name
|
||||
4. Click "Generate Dialog"
|
||||
|
||||
## File Organization
|
||||
|
||||
- Generated single utterances are saved to `single_output/`
|
||||
- Dialog generation files are saved to `dialog_output/`
|
||||
- Concatenated dialog files have `_concatenated.wav` suffix
|
||||
- All files are zipped together for download
|
||||
|
||||
## Memory Management
|
||||
|
||||
The app automatically:
|
||||
- Cleans up the TTS model after each generation
|
||||
- Frees GPU memory (for CUDA/MPS devices)
|
||||
- Deletes intermediate tensors to minimize memory footprint
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **"Skipping unknown speaker"**: Add the speaker first using the speaker configuration
|
||||
- **"Sample file not found"**: Verify the audio file exists in `speaker_samples/`
|
||||
- **Memory issues**: Try enabling "Re-initialize model each line" for long dialogs
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import re
|
||||
import os
|
||||
import yaml
|
||||
import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
|
||||
def split_text_at_sentence_boundaries(text, max_length=300):
|
||||
"""
|
||||
Split text at sentence boundaries, ensuring each chunk is <= max_length.
|
||||
Returns a list of text chunks.
|
||||
"""
|
||||
# Simple regex for sentence boundaries (period, question mark, exclamation mark followed by space or end)
|
||||
sentence_pattern = r'[.!?](?:\s|$)'
|
||||
sentences = re.split(f'({sentence_pattern})', text)
|
||||
|
||||
# Recombine the split parts (the regex split keeps the delimiters as separate items)
|
||||
actual_sentences = []
|
||||
current = ""
|
||||
for i in range(0, len(sentences), 2):
|
||||
if i+1 < len(sentences):
|
||||
current = sentences[i] + sentences[i+1]
|
||||
else:
|
||||
current = sentences[i]
|
||||
if current:
|
||||
actual_sentences.append(current)
|
||||
|
||||
# Group sentences into chunks <= max_length
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for sentence in actual_sentences:
|
||||
# If adding this sentence would exceed max_length and we already have content,
|
||||
# finish the current chunk and start a new one
|
||||
if len(current_chunk) + len(sentence) > max_length and current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence
|
||||
else:
|
||||
current_chunk += sentence
|
||||
|
||||
# Add the last chunk if it has content
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
def parse_dialog_line(line):
|
||||
"""
|
||||
Parse a dialog line in the format: Name: "Text"
|
||||
Returns a tuple of (speaker, text) or None if the line doesn't match the pattern.
|
||||
"""
|
||||
pattern = r'^([^:]+):\s*"([^"]+)"$'
|
||||
match = re.match(pattern, line.strip())
|
||||
if match:
|
||||
speaker = match.group(1).strip()
|
||||
text = match.group(2).strip()
|
||||
return (speaker, text)
|
||||
return None
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate dialog audio from markdown file using Chatterbox TTS")
|
||||
parser.add_argument('--config', required=True, type=str, help='YAML config file mapping speaker names to audio samples')
|
||||
parser.add_argument('--dialog', required=True, type=str, help='Markdown dialog file')
|
||||
parser.add_argument('--output-base', required=True, type=str, help='Base name for output files (e.g., "output" for "001-output.wav")')
|
||||
parser.add_argument('--reinit-each-line', action='store_true', help='Re-initialize the model after each line to reduce memory usage')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the YAML config
|
||||
with open(args.config, 'r') as f:
|
||||
speaker_samples = yaml.safe_load(f)
|
||||
|
||||
# Load the dialog file
|
||||
with open(args.dialog, 'r') as f:
|
||||
dialog_lines = f.readlines()
|
||||
|
||||
# Initialize model only once if not reinitializing per line
|
||||
model = None
|
||||
if not args.reinit_each_line:
|
||||
print("Loading ChatterboxTTS model once for all lines...")
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
# Process each dialog line
|
||||
file_counter = 1
|
||||
summary = []
|
||||
|
||||
for line_num, line in enumerate(dialog_lines, 1):
|
||||
parsed = parse_dialog_line(line)
|
||||
if not parsed:
|
||||
print(f"Skipping line {line_num}: Not in the expected format")
|
||||
continue
|
||||
|
||||
speaker, text = parsed
|
||||
|
||||
# Check if the speaker is in the config
|
||||
if speaker not in speaker_samples:
|
||||
print(f"Warning: Speaker '{speaker}' not found in config, skipping line {line_num}")
|
||||
continue
|
||||
|
||||
sample_path = speaker_samples[speaker]
|
||||
|
||||
# Reinitialize model if needed
|
||||
if args.reinit_each_line or model is None:
|
||||
if args.reinit_each_line:
|
||||
print(f"Reinitializing model for line {line_num}...")
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
# Check if the text needs to be split (> 300 chars)
|
||||
if len(text) > 300:
|
||||
chunks = split_text_at_sentence_boundaries(text)
|
||||
chunk_files = []
|
||||
|
||||
for chunk in chunks:
|
||||
output_file = f"{file_counter:03d}-{args.output_base}.wav"
|
||||
|
||||
# Generate audio for this chunk
|
||||
wav = model.generate(chunk, audio_prompt_path=sample_path)
|
||||
ta.save(output_file, wav, model.sr)
|
||||
|
||||
chunk_files.append(output_file)
|
||||
summary.append(f"File {output_file}: {speaker} (chunk) - {chunk[:50]}...")
|
||||
file_counter += 1
|
||||
|
||||
print(f"Generated {len(chunks)} files for line {line_num} (speaker: {speaker})")
|
||||
else:
|
||||
# Generate a single file for this line
|
||||
output_file = f"{file_counter:03d}-{args.output_base}.wav"
|
||||
|
||||
# Generate audio
|
||||
wav = model.generate(text, audio_prompt_path=sample_path)
|
||||
ta.save(output_file, wav, model.sr)
|
||||
|
||||
summary.append(f"File {output_file}: {speaker} - {text[:50]}...")
|
||||
file_counter += 1
|
||||
print(f"Generated file for line {line_num} (speaker: {speaker})")
|
||||
|
||||
# Print summary
|
||||
print("\nSummary of generated files:")
|
||||
for entry in summary:
|
||||
print(entry)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,22 @@
|
|||
import argparse
|
||||
import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Chatterbox TTS audio generation")
|
||||
parser.add_argument('--sample', required=True, type=str, help='Prompt/reference audio file (e.g. .wav, .mp3) for the voice')
|
||||
parser.add_argument('--output', required=True, type=str, help='Output audio file path (should end with .wav)')
|
||||
parser.add_argument('--text', required=True, type=str, help='Text to synthesize')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load model on MPS (for Apple Silicon)
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
# Generate the audio
|
||||
wav = model.generate(args.text, audio_prompt_path=args.sample)
|
||||
# Save to output .wav
|
||||
ta.save(args.output, wav, model.sr)
|
||||
print(f"Generated audio saved to {args.output}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,13 @@
|
|||
import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
text = "Sometimes you have to wonder just what's going on with this crazy fucking world."
|
||||
#wav = model.generate(text)
|
||||
#ta.save("test-1.wav", wav, model.sr)
|
||||
|
||||
# If you want to synthesize with a different voice, specify the audio prompt
|
||||
AUDIO_PROMPT_PATH="sample.mp3"
|
||||
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
|
||||
ta.save("test-2.wav", wav, model.sr)
|
|
@ -0,0 +1,244 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import perth
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .models.t3 import T3
|
||||
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
|
||||
from .models.s3gen import S3GEN_SR, S3Gen
|
||||
from .models.tokenizers import EnTokenizer
|
||||
from .models.voice_encoder import VoiceEncoder
|
||||
from .models.t3.modules.cond_enc import T3Cond
|
||||
|
||||
|
||||
REPO_ID = "ResembleAI/chatterbox"
|
||||
|
||||
|
||||
def punc_norm(text: str) -> str:
|
||||
"""
|
||||
Quick cleanup func for punctuation from LLMs or
|
||||
containing chars not seen often in the dataset
|
||||
"""
|
||||
if len(text) == 0:
|
||||
return "You need to add some text for me to talk."
|
||||
|
||||
# Capitalise first letter
|
||||
if text[0].islower():
|
||||
text = text[0].upper() + text[1:]
|
||||
|
||||
# Remove multiple space chars
|
||||
text = " ".join(text.split())
|
||||
|
||||
# Replace uncommon/llm punc
|
||||
punc_to_replace = [
|
||||
("...", ", "),
|
||||
("…", ", "),
|
||||
(":", ","),
|
||||
(" - ", ", "),
|
||||
(";", ", "),
|
||||
("—", "-"),
|
||||
("–", "-"),
|
||||
(" ,", ","),
|
||||
("“", "\""),
|
||||
("”", "\""),
|
||||
("‘", "'"),
|
||||
("’", "'"),
|
||||
]
|
||||
for old_char_sequence, new_char in punc_to_replace:
|
||||
text = text.replace(old_char_sequence, new_char)
|
||||
|
||||
# Add full stop if no ending punc
|
||||
text = text.rstrip(" ")
|
||||
sentence_enders = {".", "!", "?", "-", ","}
|
||||
if not any(text.endswith(p) for p in sentence_enders):
|
||||
text += "."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conditionals:
|
||||
"""
|
||||
Conditionals for T3 and S3Gen
|
||||
- T3 conditionals:
|
||||
- speaker_emb
|
||||
- clap_emb
|
||||
- cond_prompt_speech_tokens
|
||||
- cond_prompt_speech_emb
|
||||
- emotion_adv
|
||||
- S3Gen conditionals:
|
||||
- prompt_token
|
||||
- prompt_token_len
|
||||
- prompt_feat
|
||||
- prompt_feat_len
|
||||
- embedding
|
||||
"""
|
||||
t3: T3Cond
|
||||
gen: dict
|
||||
|
||||
def to(self, device):
|
||||
self.t3 = self.t3.to(device=device)
|
||||
for k, v in self.gen.items():
|
||||
if torch.is_tensor(v):
|
||||
self.gen[k] = v.to(device=device)
|
||||
return self
|
||||
|
||||
def save(self, fpath: Path):
|
||||
arg_dict = dict(
|
||||
t3=self.t3.__dict__,
|
||||
gen=self.gen
|
||||
)
|
||||
torch.save(arg_dict, fpath)
|
||||
|
||||
@classmethod
|
||||
def load(cls, fpath, map_location="cpu"):
|
||||
kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
|
||||
return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
|
||||
|
||||
|
||||
class ChatterboxTTS:
|
||||
ENC_COND_LEN = 6 * S3_SR
|
||||
DEC_COND_LEN = 10 * S3GEN_SR
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
t3: T3,
|
||||
s3gen: S3Gen,
|
||||
ve: VoiceEncoder,
|
||||
tokenizer: EnTokenizer,
|
||||
device: str,
|
||||
conds: Conditionals = None,
|
||||
):
|
||||
self.sr = S3GEN_SR # sample rate of synthesized audio
|
||||
self.t3 = t3
|
||||
self.s3gen = s3gen
|
||||
self.ve = ve
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.conds = conds
|
||||
self.watermarker = perth.PerthImplicitWatermarker()
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
|
||||
ckpt_dir = Path(ckpt_dir)
|
||||
|
||||
ve = VoiceEncoder()
|
||||
ve.load_state_dict(
|
||||
torch.load(ckpt_dir / "ve.pt")
|
||||
)
|
||||
ve.to(device).eval()
|
||||
|
||||
t3 = T3()
|
||||
t3_state = torch.load(ckpt_dir / "t3_cfg.pt")
|
||||
if "model" in t3_state.keys():
|
||||
t3_state = t3_state["model"][0]
|
||||
t3.load_state_dict(t3_state)
|
||||
t3.to(device).eval()
|
||||
|
||||
s3gen = S3Gen()
|
||||
s3gen.load_state_dict(
|
||||
torch.load(ckpt_dir / "s3gen.pt")
|
||||
)
|
||||
s3gen.to(device).eval()
|
||||
|
||||
tokenizer = EnTokenizer(
|
||||
str(ckpt_dir / "tokenizer.json")
|
||||
)
|
||||
|
||||
conds = None
|
||||
if (builtin_voice := ckpt_dir / "conds.pt").exists():
|
||||
conds = Conditionals.load(builtin_voice).to(device)
|
||||
|
||||
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, device) -> 'ChatterboxTTS':
|
||||
for fpath in ["ve.pt", "t3_cfg.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]:
|
||||
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
|
||||
|
||||
return cls.from_local(Path(local_path).parent, device)
|
||||
|
||||
def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
|
||||
## Load reference wav
|
||||
s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
|
||||
|
||||
ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
|
||||
|
||||
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
|
||||
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
|
||||
|
||||
# Speech cond prompt tokens
|
||||
if plen := self.t3.hp.speech_cond_prompt_len:
|
||||
s3_tokzr = self.s3gen.tokenizer
|
||||
t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
|
||||
t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
|
||||
|
||||
# Voice-encoder speaker embedding
|
||||
ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
|
||||
ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
|
||||
|
||||
t3_cond = T3Cond(
|
||||
speaker_emb=ve_embed,
|
||||
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
|
||||
emotion_adv=exaggeration * torch.ones(1, 1, 1),
|
||||
).to(device=self.device)
|
||||
self.conds = Conditionals(t3_cond, s3gen_ref_dict)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
text,
|
||||
audio_prompt_path=None,
|
||||
exaggeration=0.5,
|
||||
cfg_weight=0.5,
|
||||
temperature=0.8,
|
||||
):
|
||||
if audio_prompt_path:
|
||||
self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
|
||||
else:
|
||||
assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
|
||||
|
||||
# Update exaggeration if needed
|
||||
if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]:
|
||||
_cond: T3Cond = self.conds.t3
|
||||
self.conds.t3 = T3Cond(
|
||||
speaker_emb=_cond.speaker_emb,
|
||||
cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
|
||||
emotion_adv=exaggeration * torch.ones(1, 1, 1),
|
||||
).to(device=self.device)
|
||||
|
||||
# Norm and tokenize text
|
||||
text = punc_norm(text)
|
||||
text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
|
||||
text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
|
||||
|
||||
sot = self.t3.hp.start_text_token
|
||||
eot = self.t3.hp.stop_text_token
|
||||
text_tokens = F.pad(text_tokens, (1, 0), value=sot)
|
||||
text_tokens = F.pad(text_tokens, (0, 1), value=eot)
|
||||
|
||||
with torch.inference_mode():
|
||||
speech_tokens = self.t3.inference(
|
||||
t3_cond=self.conds.t3,
|
||||
text_tokens=text_tokens,
|
||||
max_new_tokens=1000, # TODO: use the value in config
|
||||
temperature=temperature,
|
||||
cfg_weight=cfg_weight,
|
||||
)
|
||||
# Extract only the conditional batch.
|
||||
speech_tokens = speech_tokens[0]
|
||||
|
||||
# TODO: output becomes 1D
|
||||
speech_tokens = drop_invalid_tokens(speech_tokens)
|
||||
speech_tokens = speech_tokens.to(self.device)
|
||||
|
||||
wav, _ = self.s3gen.inference(
|
||||
speech_tokens=speech_tokens,
|
||||
ref_dict=self.conds.gen,
|
||||
)
|
||||
wav = wav.squeeze(0).detach().cpu().numpy()
|
||||
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
|
||||
return torch.from_numpy(watermarked_wav).unsqueeze(0)
|
|
@ -0,0 +1,665 @@
|
|||
import gradio as gr
|
||||
import yaml
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
import torchaudio as ta
|
||||
import torch
|
||||
from typing import Dict, Tuple, Optional, List
|
||||
|
||||
# Load speaker options from YAML with error handling
|
||||
try:
|
||||
yaml_path = os.path.abspath("speakers.yaml")
|
||||
if not os.path.exists(yaml_path):
|
||||
raise FileNotFoundError(f"speakers.yaml not found at {yaml_path}")
|
||||
|
||||
with open(yaml_path) as f:
|
||||
speakers = yaml.safe_load(f)
|
||||
|
||||
if not speakers or not isinstance(speakers, dict):
|
||||
raise ValueError("speakers.yaml must contain a valid dictionary mapping")
|
||||
|
||||
except Exception as e:
|
||||
raise SystemExit(f"Failed to load speakers.yaml: {str(e)}")
|
||||
|
||||
def split_text_at_sentence_boundaries(text, max_length=300):
|
||||
"""Split text at sentence boundaries, ensuring each chunk is <= max_length."""
|
||||
sentence_pattern = r'[.!?](?:\s|$)'
|
||||
sentences = re.split(f'({sentence_pattern})', text)
|
||||
|
||||
actual_sentences = []
|
||||
current = ""
|
||||
for i in range(0, len(sentences), 2):
|
||||
if i+1 < len(sentences):
|
||||
current = sentences[i] + sentences[i+1]
|
||||
else:
|
||||
current = sentences[i]
|
||||
if current:
|
||||
actual_sentences.append(current)
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for sentence in actual_sentences:
|
||||
if len(current_chunk) + len(sentence) > max_length and current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence
|
||||
else:
|
||||
current_chunk += sentence
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
def parse_dialog_line(line):
|
||||
"""Parse a dialog line to extract speaker and text."""
|
||||
# Handle silence lines
|
||||
silence_match = re.match(r'^Silence:\s*([0-9]*\.?[0-9]+)$', line.strip(), re.IGNORECASE)
|
||||
if silence_match:
|
||||
duration = float(silence_match.group(1))
|
||||
return "SILENCE", duration
|
||||
|
||||
# Handle regular dialog lines
|
||||
match = re.match(r'^([^:]+):\s*["\'](.+)["\']$', line.strip())
|
||||
if match:
|
||||
speaker = match.group(1).strip()
|
||||
text = match.group(2).strip()
|
||||
return speaker, text
|
||||
return None
|
||||
|
||||
def save_speakers_config(speakers_dict: Dict) -> str:
|
||||
"""Save speakers configuration to YAML file."""
|
||||
try:
|
||||
with open("speakers.yaml", 'w') as f:
|
||||
yaml.dump(speakers_dict, f)
|
||||
return "Speakers configuration saved successfully!"
|
||||
except Exception as e:
|
||||
return f"Error saving configuration: {str(e)}"
|
||||
|
||||
def add_speaker(speaker_name: str, audio_file: str, current_speakers: Dict) -> Tuple[str, Dict]:
|
||||
"""Add a new speaker with audio sample."""
|
||||
if not speaker_name or not audio_file:
|
||||
return "Please provide both speaker name and audio file", current_speakers
|
||||
|
||||
if speaker_name in current_speakers:
|
||||
return f"Speaker '{speaker_name}' already exists!", current_speakers
|
||||
|
||||
# Save the audio file
|
||||
speakers_dir = "speaker_samples"
|
||||
os.makedirs(speakers_dir, exist_ok=True)
|
||||
|
||||
# Generate a unique filename
|
||||
ext = os.path.splitext(audio_file)[1] or '.wav'
|
||||
new_filename = f"{speaker_name.lower().replace(' ', '_')}{ext}"
|
||||
new_filepath = os.path.join(speakers_dir, new_filename)
|
||||
|
||||
# Copy the uploaded file
|
||||
shutil.copy2(audio_file, new_filepath)
|
||||
|
||||
# Update speakers dictionary
|
||||
updated_speakers = dict(current_speakers)
|
||||
updated_speakers[speaker_name] = new_filepath
|
||||
|
||||
# Save to YAML
|
||||
save_speakers_config(updated_speakers)
|
||||
|
||||
return f"Speaker '{speaker_name}' added successfully!", updated_speakers
|
||||
|
||||
def remove_speaker(speaker_name: str, current_speakers: Dict) -> Tuple[str, Dict]:
|
||||
"""Remove a speaker from the configuration."""
|
||||
if not speaker_name or speaker_name not in current_speakers:
|
||||
return "Speaker not found!", current_speakers
|
||||
|
||||
# Don't actually delete the audio file, just remove from config
|
||||
updated_speakers = dict(current_speakers)
|
||||
del updated_speakers[speaker_name]
|
||||
|
||||
# Save to YAML
|
||||
save_speakers_config(updated_speakers)
|
||||
|
||||
return f"Speaker '{speaker_name}' removed!", updated_speakers
|
||||
|
||||
def update_speakers_dropdown():
|
||||
"""Update the speakers dropdown with current speakers."""
|
||||
return gr.Dropdown.update(choices=list(speakers.keys()))
|
||||
|
||||
def generate_audio(speaker_choice, custom_sample, text, exaggeration, cfg_weight, temperature, max_new_tokens):
|
||||
import torch # Ensure torch is available in function scope
|
||||
tts = None
|
||||
wav = None
|
||||
try:
|
||||
# Get sample path from selection or upload
|
||||
sample_path = speakers[speaker_choice] if speaker_choice != "Custom" else custom_sample
|
||||
|
||||
if not os.path.exists(sample_path):
|
||||
raise gr.Error("Sample file not found!")
|
||||
|
||||
# Load model (cached automatically by Gradio)
|
||||
tts = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
# Generate audio with advanced controls - disable gradients for inference
|
||||
with torch.no_grad():
|
||||
gen_kwargs = dict(
|
||||
text=text,
|
||||
audio_prompt_path=sample_path,
|
||||
exaggeration=exaggeration,
|
||||
cfg_weight=cfg_weight,
|
||||
temperature=temperature
|
||||
)
|
||||
# max_new_tokens is not supported by the current TTS library, so we ignore it here
|
||||
wav = tts.generate(**gen_kwargs)
|
||||
|
||||
# Create output directory for single utterances
|
||||
output_dir = "single_output"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save with timestamp in the subdirectory
|
||||
output_path = os.path.join(output_dir, f"output_{int(time.time())}.wav")
|
||||
ta.save(output_path, wav, tts.sr)
|
||||
|
||||
return output_path, output_path
|
||||
|
||||
finally:
|
||||
# Clean up audio tensor first
|
||||
if wav is not None:
|
||||
try:
|
||||
del wav
|
||||
except:
|
||||
pass
|
||||
|
||||
# Clean up model to free memory
|
||||
if tts is not None:
|
||||
print("Cleaning up TTS model to free memory...") # Debug log
|
||||
try:
|
||||
# Move model to CPU and delete reference
|
||||
if hasattr(tts, 'cpu'):
|
||||
tts.cpu()
|
||||
del tts
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Clear GPU cache - both CUDA and MPS
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
# Clear MPS cache for Apple Silicon
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
except:
|
||||
pass
|
||||
|
||||
print("Model cleanup completed") # Debug log
|
||||
except Exception as cleanup_error:
|
||||
print(f"Error during model cleanup: {cleanup_error}") # Debug log
|
||||
|
||||
def process_dialog(dialog_text, speaker_samples, output_base, reinit_each_line, progress=gr.Progress()):
|
||||
"""Process dialog text and generate audio files."""
|
||||
import torch # Ensure torch is available in function scope
|
||||
model = None
|
||||
try:
|
||||
print("Starting dialog processing...") # Debug log
|
||||
print(f"Speaker samples: {speaker_samples}") # Debug log
|
||||
|
||||
if not dialog_text or not dialog_text.strip():
|
||||
return "Error: No dialog text provided", None
|
||||
|
||||
# Parse dialog lines
|
||||
dialog_lines = [line.strip() for line in dialog_text.split('\n') if line.strip()]
|
||||
|
||||
if not dialog_lines:
|
||||
return "Error: No valid dialog lines found", None
|
||||
|
||||
print(f"Processing {len(dialog_lines)} dialog lines") # Debug log
|
||||
|
||||
# Initialize model only once if not reinitializing per line
|
||||
if not reinit_each_line:
|
||||
progress(0.1, desc="Loading TTS model...")
|
||||
try:
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
print("TTS model loaded successfully") # Debug log
|
||||
except Exception as e:
|
||||
return f"Error loading TTS model: {str(e)}", None
|
||||
|
||||
# Create output directory
|
||||
output_dir = "dialog_output"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Process each dialog line
|
||||
file_counter = 1
|
||||
output_files = []
|
||||
silence_durations = [] # Track silence durations for concatenation
|
||||
summary = []
|
||||
|
||||
for i, line in enumerate(dialog_lines):
|
||||
progress(i / len(dialog_lines), desc=f"Processing line {i+1}/{len(dialog_lines)}")
|
||||
print(f"Processing line {i+1}: {line}") # Debug log
|
||||
|
||||
try:
|
||||
parsed = parse_dialog_line(line)
|
||||
if not parsed:
|
||||
print(f"Skipping line (invalid format): {line}") # Debug log
|
||||
continue
|
||||
|
||||
speaker, text_or_duration = parsed
|
||||
print(f"Found speaker: {speaker}, content: {text_or_duration}") # Debug log
|
||||
|
||||
# Handle silence entries
|
||||
if speaker == "SILENCE":
|
||||
duration = text_or_duration
|
||||
print(f"Adding silence duration: {duration} seconds") # Debug log
|
||||
silence_durations.append(duration)
|
||||
summary.append(f"Silence: {duration} seconds")
|
||||
continue
|
||||
|
||||
# Handle regular dialog
|
||||
text = text_or_duration
|
||||
if speaker not in speaker_samples:
|
||||
msg = f"Skipping unknown speaker: {speaker}"
|
||||
print(msg) # Debug log
|
||||
summary.append(msg)
|
||||
# Add default silence duration if we skip a speaker
|
||||
silence_durations.append(1.0)
|
||||
continue
|
||||
|
||||
sample_path = speaker_samples[speaker]
|
||||
if not os.path.exists(sample_path):
|
||||
msg = f"Audio sample not found for speaker '{speaker}': {sample_path}"
|
||||
print(msg) # Debug log
|
||||
summary.append(msg)
|
||||
# Add default silence duration if we skip due to missing sample
|
||||
silence_durations.append(1.0)
|
||||
continue
|
||||
|
||||
if reinit_each_line or model is None:
|
||||
print("Initializing new TTS model instance") # Debug log
|
||||
model = ChatterboxTTS.from_pretrained(device="mps")
|
||||
|
||||
if len(text) > 300:
|
||||
chunks = split_text_at_sentence_boundaries(text)
|
||||
print(f"Splitting long text into {len(chunks)} chunks") # Debug log
|
||||
for chunk in chunks:
|
||||
output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
|
||||
print(f"Generating audio for chunk: {chunk[:50]}...") # Debug log
|
||||
try:
|
||||
with torch.no_grad():
|
||||
wav = model.generate(chunk, audio_prompt_path=sample_path)
|
||||
ta.save(output_file, wav, model.sr)
|
||||
# Clean up wav tensor immediately
|
||||
del wav
|
||||
output_files.append(output_file)
|
||||
summary.append(f"{output_file}: {speaker} (chunk) - {chunk[:50]}...")
|
||||
file_counter += 1
|
||||
print(f"Generated audio: {output_file}") # Debug log
|
||||
except Exception as e:
|
||||
error_msg = f"Error generating audio for chunk: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
summary.append(error_msg)
|
||||
continue
|
||||
else:
|
||||
output_file = os.path.join(output_dir, f"{file_counter:03d}-{output_base}.wav")
|
||||
print(f"Generating audio: {text[:50]}...") # Debug log
|
||||
try:
|
||||
with torch.no_grad():
|
||||
wav = model.generate(text, audio_prompt_path=sample_path)
|
||||
ta.save(output_file, wav, model.sr)
|
||||
# Clean up wav tensor immediately
|
||||
del wav
|
||||
output_files.append(output_file)
|
||||
summary.append(f"{output_file}: {speaker} - {text[:50]}...")
|
||||
file_counter += 1
|
||||
print(f"Generated audio: {output_file}") # Debug log
|
||||
except Exception as e:
|
||||
error_msg = f"Error generating audio: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
summary.append(error_msg)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing line '{line}': {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
summary.append(error_msg)
|
||||
continue
|
||||
|
||||
if not output_files:
|
||||
return "Error: No audio files were generated. Check speaker names and audio samples.", None
|
||||
|
||||
# Concatenate all audio files with configurable gaps
|
||||
concatenated_file = None
|
||||
waveforms = []
|
||||
concatenated = None
|
||||
try:
|
||||
if len(output_files) > 1:
|
||||
print("Concatenating audio files...") # Debug log
|
||||
# Load all audio files
|
||||
sample_rates = set()
|
||||
|
||||
for file in output_files:
|
||||
waveform, sample_rate = ta.load(file)
|
||||
waveforms.append(waveform)
|
||||
sample_rates.add(sample_rate)
|
||||
|
||||
if len(sample_rates) != 1:
|
||||
raise ValueError(f"Sample rate mismatch: {sample_rates}")
|
||||
|
||||
sample_rate = sample_rates.pop()
|
||||
|
||||
# Start with the first audio segment
|
||||
concatenated = waveforms[0]
|
||||
|
||||
# Add subsequent segments with appropriate silence gaps
|
||||
for i, wav in enumerate(waveforms[1:], 1):
|
||||
# Determine silence duration for this gap
|
||||
if i-1 < len(silence_durations):
|
||||
gap_duration = silence_durations[i-1]
|
||||
else:
|
||||
gap_duration = 1.0 # Default 1 second
|
||||
|
||||
gap_samples = int(gap_duration * sample_rate)
|
||||
gap = torch.zeros(1, gap_samples) # Mono channel
|
||||
|
||||
print(f"Adding {gap_duration}s gap before segment {i+1}") # Debug log
|
||||
concatenated = torch.cat([concatenated, gap, wav], dim=1)
|
||||
|
||||
# Clean up gap tensor immediately
|
||||
del gap
|
||||
|
||||
# Save concatenated file
|
||||
concatenated_path = os.path.join(output_dir, f"{output_base}_concatenated.wav")
|
||||
ta.save(concatenated_path, concatenated, sample_rate)
|
||||
output_files.append(concatenated_path)
|
||||
summary.append(f"\nConcatenated file: {concatenated_path}")
|
||||
concatenated_file = concatenated_path
|
||||
print(f"Created concatenated file: {concatenated_path}")
|
||||
|
||||
# Create a zip file of all outputs
|
||||
import zipfile
|
||||
zip_path = os.path.join(output_dir, f"{output_base}.zip")
|
||||
print(f"Creating zip file: {zip_path}") # Debug log
|
||||
with zipfile.ZipFile(zip_path, 'w') as zipf:
|
||||
for file in output_files:
|
||||
zipf.write(file, os.path.basename(file))
|
||||
print(f"Zip file created successfully with {len(output_files)} files") # Debug log
|
||||
|
||||
# Return both the zip and the concatenated file if it exists
|
||||
if concatenated_file:
|
||||
return "\n".join(summary), concatenated_file, zip_path
|
||||
return "\n".join(summary), None, zip_path
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating zip file: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
return "\n".join(summary + [error_msg]), None
|
||||
finally:
|
||||
# Clean up concatenation tensors
|
||||
try:
|
||||
if waveforms:
|
||||
for waveform in waveforms:
|
||||
del waveform
|
||||
waveforms.clear()
|
||||
if concatenated is not None:
|
||||
del concatenated
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
print("Concatenation tensors cleaned up") # Debug log
|
||||
except Exception as cleanup_error:
|
||||
print(f"Error during concatenation cleanup: {cleanup_error}") # Debug log
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error: {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
import traceback
|
||||
traceback.print_exc() # Print full traceback
|
||||
return error_msg, None
|
||||
|
||||
finally:
|
||||
# Clean up model to free memory
|
||||
if 'model' in locals() and model is not None:
|
||||
print("Cleaning up TTS model to free memory...") # Debug log
|
||||
try:
|
||||
# Move model to CPU and delete reference
|
||||
if hasattr(model, 'cpu'):
|
||||
model.cpu()
|
||||
del model
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Clear GPU cache - both CUDA and MPS
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
# Clear MPS cache for Apple Silicon
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
except:
|
||||
pass
|
||||
|
||||
print("Model cleanup completed") # Debug log
|
||||
except Exception as cleanup_error:
|
||||
print(f"Error during model cleanup: {cleanup_error}") # Debug log
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# Chatterbox TTS Generator")
|
||||
|
||||
# Store speakers in a global state
|
||||
speakers_state = gr.State(speakers)
|
||||
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.TabItem("Single Utterance"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
speaker_dropdown = gr.Dropdown(
|
||||
choices=["Custom"] + list(speakers.keys()),
|
||||
value="Custom",
|
||||
label="Select Speaker"
|
||||
)
|
||||
custom_upload = gr.Audio(
|
||||
label="Or upload custom speaker sample",
|
||||
type="filepath",
|
||||
visible=True
|
||||
)
|
||||
text_input = gr.Textbox(
|
||||
label="Text to synthesize",
|
||||
placeholder="Enter text here...",
|
||||
lines=3
|
||||
)
|
||||
exaggeration_slider = gr.Slider(
|
||||
minimum=0.0, maximum=2.0, value=0.5, step=0.01,
|
||||
label="Exaggeration (emotion)",
|
||||
info="Controls expressiveness. 0.5 = neutral, higher = more expressive."
|
||||
)
|
||||
cfg_weight_slider = gr.Slider(
|
||||
minimum=0.0, maximum=2.0, value=0.5, step=0.01,
|
||||
label="CFG Weight",
|
||||
info="Higher = more faithful to text, lower = more like reference voice."
|
||||
)
|
||||
temperature_slider = gr.Slider(
|
||||
minimum=0.1, maximum=2.0, value=0.8, step=0.01,
|
||||
label="Temperature",
|
||||
info="Controls randomness. Higher = more variation."
|
||||
)
|
||||
max_new_tokens_box = gr.Number(
|
||||
value=1000,
|
||||
label="Max New Tokens (advanced)",
|
||||
precision=0,
|
||||
info="Maximum audio tokens to generate. Increase for longer texts."
|
||||
)
|
||||
generate_btn = gr.Button("Generate Speech")
|
||||
|
||||
with gr.Column():
|
||||
audio_output = gr.Audio(label="Generated Speech")
|
||||
download = gr.File(label="Download WAV")
|
||||
|
||||
gr.Examples(
|
||||
examples=[
|
||||
["Hello world! This is a demo.", "Tara"],
|
||||
["Welcome to the future of text-to-speech.", "Zac"]
|
||||
],
|
||||
inputs=[text_input, speaker_dropdown]
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
fn=generate_audio,
|
||||
inputs=[speaker_dropdown, custom_upload, text_input, exaggeration_slider,
|
||||
cfg_weight_slider, temperature_slider, max_new_tokens_box],
|
||||
outputs=[audio_output, download]
|
||||
)
|
||||
|
||||
with gr.TabItem("Dialog Generation"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
dialog_text = gr.Textbox(
|
||||
label="Dialog Text",
|
||||
placeholder='''Enter dialog in format:
|
||||
Speaker1: "Hello, how are you?"
|
||||
Speaker2: "I'm doing well, thank you!"
|
||||
Speaker1: "What are your plans for today?"
|
||||
Speaker2: "I'm working on a new project."''',
|
||||
lines=10
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
with gr.Group():
|
||||
gr.Markdown("### Speaker Configuration")
|
||||
with gr.Row():
|
||||
new_speaker_name = gr.Textbox(
|
||||
label="New Speaker Name",
|
||||
placeholder="Enter speaker name"
|
||||
)
|
||||
new_speaker_audio = gr.Audio(
|
||||
label="Speaker Sample",
|
||||
type="filepath"
|
||||
)
|
||||
with gr.Row():
|
||||
add_speaker_btn = gr.Button("Add Speaker")
|
||||
remove_speaker_btn = gr.Button("Remove Selected")
|
||||
|
||||
speakers_dropdown = gr.Dropdown(
|
||||
label="Available Speakers",
|
||||
choices=list(speakers.keys()) if speakers else [],
|
||||
interactive=True,
|
||||
multiselect=True
|
||||
)
|
||||
|
||||
gr.Markdown("### Generation Settings")
|
||||
with gr.Row():
|
||||
output_base = gr.Textbox(
|
||||
label="Output Base Name",
|
||||
value="dialog_output",
|
||||
placeholder="base_name (will generate 001-base_name.wav, etc.)"
|
||||
)
|
||||
reinit_each_line = gr.Checkbox(
|
||||
label="Re-initialize model each line",
|
||||
value=False,
|
||||
info="Reduces memory usage but is slower"
|
||||
)
|
||||
|
||||
config_status = gr.Textbox(
|
||||
label="Status",
|
||||
interactive=False,
|
||||
visible=True
|
||||
)
|
||||
|
||||
dialog_generate_btn = gr.Button("Generate Dialog")
|
||||
|
||||
with gr.Column():
|
||||
dialog_output = gr.Textbox(
|
||||
label="Generation Log",
|
||||
interactive=False,
|
||||
lines=15
|
||||
)
|
||||
concatenated_audio = gr.Audio(
|
||||
label="Concatenated Audio",
|
||||
visible=False
|
||||
)
|
||||
dialog_download = gr.File(
|
||||
label="Download All Files",
|
||||
visible=False
|
||||
)
|
||||
|
||||
# Event handlers
|
||||
add_speaker_btn.click(
|
||||
fn=add_speaker,
|
||||
inputs=[new_speaker_name, new_speaker_audio, speakers_state],
|
||||
outputs=[config_status, speakers_state]
|
||||
)
|
||||
|
||||
remove_speaker_btn.click(
|
||||
fn=remove_speaker,
|
||||
inputs=[speakers_dropdown, speakers_state],
|
||||
outputs=[config_status, speakers_state]
|
||||
)
|
||||
|
||||
# Update the speakers dropdown when speakers_state changes
|
||||
def update_dropdown(speakers_dict):
|
||||
return gr.Dropdown(choices=list(speakers_dict.keys()) if speakers_dict else [])
|
||||
|
||||
speakers_state.change(
|
||||
fn=update_dropdown,
|
||||
inputs=[speakers_state],
|
||||
outputs=[speakers_dropdown]
|
||||
)
|
||||
|
||||
# Update the speakers dropdown when the tab is selected
|
||||
def on_tab_change():
|
||||
# Return a dictionary that updates the dropdown choices
|
||||
return {"__type__": "update", "choices": list(speakers.keys())}
|
||||
|
||||
tabs.select(
|
||||
fn=on_tab_change,
|
||||
inputs=[],
|
||||
outputs=[speakers_dropdown]
|
||||
)
|
||||
|
||||
def update_outputs(*args):
|
||||
result = process_dialog(*args)
|
||||
if len(result) == 3:
|
||||
summary, concat_file, zip_file = result
|
||||
if concat_file:
|
||||
# When we have a concatenated file, show both the audio player and download
|
||||
return [
|
||||
summary, # dialog_output
|
||||
gr.Audio(value=concat_file, visible=True), # concatenated_audio
|
||||
gr.File(value=zip_file, visible=True) # dialog_download
|
||||
]
|
||||
# When no concatenated file, just show the zip download
|
||||
return [
|
||||
summary,
|
||||
gr.Audio(visible=False),
|
||||
gr.File(value=zip_file, visible=True)
|
||||
]
|
||||
# Error case
|
||||
return [
|
||||
result[0], # error message
|
||||
gr.Audio(visible=False),
|
||||
gr.File(visible=False)
|
||||
]
|
||||
|
||||
# Update the click handler with the correct number of outputs
|
||||
dialog_generate_btn.click(
|
||||
fn=update_outputs,
|
||||
inputs=[
|
||||
dialog_text,
|
||||
speakers_state, # Pass the current speakers dict
|
||||
output_base,
|
||||
reinit_each_line
|
||||
],
|
||||
outputs=[
|
||||
dialog_output, # Text output
|
||||
concatenated_audio, # Audio component
|
||||
dialog_download # Zip file
|
||||
]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(share=True)
|
|
@ -0,0 +1,5 @@
|
|||
Leah: "What do you think is wrong with me?"
|
||||
Zac: "I think you're being overly emotional."
|
||||
Tara: "Jesus, Mark, can you be any more of an asshole?"
|
||||
Leah: "This is a longer line that will demonstrate how the script handles text that exceeds the 300 character limit. It will be split at sentence boundaries to ensure that the generated audio files are of a reasonable length. This sentence adds more characters. And this one adds even more to push us over the 300 character limit. The script should create multiple audio files for this single dialog line, while keeping the sentence structure intact."
|
||||
Zac: "I didn't mean to upset anyone. I was just trying to be honest."
|
|
@ -0,0 +1,11 @@
|
|||
Adam: Adam.mp3
|
||||
Alice: Alice.mp3
|
||||
David: speaker_samples/david.mp3
|
||||
Debbie: speaker_samples/debbie.mp3
|
||||
Denise: speaker_samples/denise.mp3
|
||||
Leah: Leah.mp3
|
||||
Leo: Leo.mp3
|
||||
Lewis: Lewis.mp3
|
||||
Maria: speaker_samples/maria.mp3
|
||||
Tara: Tara.mp3
|
||||
Zac: Zac.mp3
|
Loading…
Reference in New Issue