chatterbox-ui/cbx-dialog-generate.py

147 lines
5.3 KiB
Python

#!/usr/bin/env python3
import argparse
import re
import os
import yaml
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
# Import helper to fix Python path
import import_helper
def split_text_at_sentence_boundaries(text, max_length=300):
    """
    Split text at sentence boundaries, ensuring each chunk is <= max_length.

    Sentences are detected with a simple regex: '.', '!' or '?' followed by
    whitespace or end-of-string. Sentences are packed greedily into chunks.
    A single sentence longer than max_length is additionally split on
    whitespace so the <= max_length guarantee actually holds (previously such
    a sentence was emitted as one oversized chunk). A single word longer than
    max_length is still emitted as-is.

    Args:
        text: The input text to split.
        max_length: Maximum length of each returned chunk.

    Returns:
        A list of non-empty, whitespace-stripped text chunks.
    """
    sentence_pattern = r'[.!?](?:\s|$)'
    # re.split with a capturing group keeps each delimiter as its own item,
    # so parts alternate: [sentence, delimiter, sentence, delimiter, ...].
    parts = re.split(f'({sentence_pattern})', text)
    sentences = []
    for i in range(0, len(parts), 2):
        delimiter = parts[i + 1] if i + 1 < len(parts) else ''
        sentence = parts[i] + delimiter
        if sentence:
            sentences.append(sentence)

    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(sentence) > max_length:
            # Oversized sentence: flush the pending chunk, then pack the
            # sentence word-by-word so no emitted chunk exceeds max_length.
            if current_chunk:
                chunks.append(current_chunk.strip())
                current_chunk = ""
            for word in sentence.split():
                extra = len(word) + (1 if current_chunk else 0)
                if current_chunk and len(current_chunk) + extra > max_length:
                    chunks.append(current_chunk.strip())
                    current_chunk = word
                else:
                    current_chunk = f"{current_chunk} {word}" if current_chunk else word
            continue
        # Greedy packing: start a new chunk only when adding this sentence
        # would overflow a non-empty chunk.
        if current_chunk and len(current_chunk) + len(sentence) > max_length:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += sentence
    # Flush the final chunk if it has content.
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def parse_dialog_line(line):
    """
    Parse a dialog line of the form: Name: "Text"

    Args:
        line: A raw line from the dialog file (leading/trailing whitespace
            is ignored).

    Returns:
        A ``(speaker, text)`` tuple with surrounding whitespace stripped,
        or ``None`` when the line does not match the expected pattern.
    """
    match = re.match(r'^([^:]+):\s*"([^"]+)"$', line.strip())
    if match is None:
        return None
    speaker, text = match.groups()
    return speaker.strip(), text.strip()
def main():
    """
    CLI entry point: generate per-line dialog audio with Chatterbox TTS.

    Reads a YAML config mapping speaker names to reference audio samples and
    a dialog file whose lines look like: Name: "Text". Each matching line is
    synthesized with the speaker's voice sample and saved to sequentially
    numbered WAV files (e.g. "001-<output-base>.wav"). Lines longer than 300
    characters are split at sentence boundaries into multiple files.
    """
    parser = argparse.ArgumentParser(description="Generate dialog audio from markdown file using Chatterbox TTS")
    parser.add_argument('--config', required=True, type=str, help='YAML config file mapping speaker names to audio samples')
    parser.add_argument('--dialog', required=True, type=str, help='Markdown dialog file')
    parser.add_argument('--output-base', required=True, type=str, help='Base name for output files (e.g., "output" for "001-output.wav")')
    parser.add_argument('--reinit-each-line', action='store_true', help='Re-initialize the model after each line to reduce memory usage')
    args = parser.parse_args()

    # Speaker name -> path of the reference audio sample used for the voice.
    with open(args.config, 'r', encoding='utf-8') as f:
        speaker_samples = yaml.safe_load(f)

    with open(args.dialog, 'r', encoding='utf-8') as f:
        dialog_lines = f.readlines()

    # Load the model once up front unless the caller asked for per-line
    # re-initialization (a memory-pressure workaround).
    # NOTE(review): device is hard-coded to "mps" (Apple Silicon) — confirm
    # this is intentional for the target deployment.
    model = None
    if not args.reinit_each_line:
        print("Loading ChatterboxTTS model once for all lines...")
        model = ChatterboxTTS.from_pretrained(device="mps")

    file_counter = 1  # global counter so output files sort in dialog order
    summary = []
    for line_num, line in enumerate(dialog_lines, 1):
        parsed = parse_dialog_line(line)
        if not parsed:
            print(f"Skipping line {line_num}: Not in the expected format")
            continue
        speaker, text = parsed

        # Skip speakers with no configured voice sample.
        if speaker not in speaker_samples:
            print(f"Warning: Speaker '{speaker}' not found in config, skipping line {line_num}")
            continue
        sample_path = speaker_samples[speaker]

        if args.reinit_each_line:
            print(f"Reinitializing model for line {line_num}...")
            model = ChatterboxTTS.from_pretrained(device="mps")
        elif model is None:
            # Defensive fallback: never call generate() on a missing model.
            model = ChatterboxTTS.from_pretrained(device="mps")

        # Long lines are split at sentence boundaries; short lines are a
        # single chunk. Both cases share the synthesis loop below.
        if len(text) > 300:
            chunks = split_text_at_sentence_boundaries(text)
        else:
            chunks = [text]

        for chunk in chunks:
            output_file = f"{file_counter:03d}-{args.output_base}.wav"
            wav = model.generate(chunk, audio_prompt_path=sample_path)
            ta.save(output_file, wav, model.sr)
            label = " (chunk)" if len(chunks) > 1 else ""
            summary.append(f"File {output_file}: {speaker}{label} - {chunk[:50]}...")
            file_counter += 1

        if len(chunks) > 1:
            print(f"Generated {len(chunks)} files for line {line_num} (speaker: {speaker})")
        else:
            print(f"Generated file for line {line_num} (speaker: {speaker})")

    # Print summary of everything that was written.
    print("\nSummary of generated files:")
    for entry in summary:
        print(entry)
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == '__main__':
    main()