147 lines
5.3 KiB
Python
147 lines
5.3 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import re
|
|
import os
|
|
import yaml
|
|
import torchaudio as ta
|
|
from chatterbox.tts import ChatterboxTTS
|
|
|
|
# Import helper to fix Python path
# NOTE(review): this is imported *after* `from chatterbox.tts import ...`,
# so any sys.path fix it performs comes too late to help that import —
# it likely needs to move above the third-party imports; verify.
|
|
import import_helper
|
|
|
|
def split_text_at_sentence_boundaries(text, max_length=300):
    """Split text at sentence boundaries into chunks of at most max_length.

    Sentences are detected with a simple regex (., ! or ? followed by
    whitespace or end of string).  Sentences are then greedily grouped into
    chunks.  A single sentence longer than max_length is hard-split at word
    boundaries (and, as a last resort, at character boundaries) so that
    every returned chunk satisfies len(chunk) <= max_length — the original
    code silently emitted oversized chunks in that case, breaking the
    documented contract.

    Args:
        text: The input text to split.
        max_length: Maximum length of each returned chunk (default 300).

    Returns:
        A list of whitespace-stripped, non-empty text chunks.
    """
    # Simple regex for sentence boundaries (period, question mark,
    # exclamation mark followed by whitespace or end of string).
    sentence_pattern = r'[.!?](?:\s|$)'
    # The capture group makes re.split keep each delimiter as its own list
    # item, so the punctuation can be re-attached to the preceding text.
    parts = re.split(f'({sentence_pattern})', text)

    # Recombine text/delimiter pairs into full sentences.
    sentences = []
    for i in range(0, len(parts), 2):
        sentence = parts[i] + (parts[i + 1] if i + 1 < len(parts) else "")
        if sentence:
            sentences.append(sentence)

    # BUG FIX: hard-split any sentence that is itself longer than max_length,
    # first at word boundaries, then (for a single giant token) at character
    # boundaries.  Sentences already within the limit pass through untouched,
    # so behavior for normal input is unchanged.
    bounded = []
    for sentence in sentences:
        if len(sentence) <= max_length:
            bounded.append(sentence)
            continue
        piece = ""
        for word in sentence.split(' '):
            candidate = f"{piece} {word}" if piece else word
            if len(candidate) > max_length and piece:
                bounded.append(piece)
                piece = word
            else:
                piece = candidate
        # Last resort: a single word longer than max_length is cut by length.
        while len(piece) > max_length:
            bounded.append(piece[:max_length])
            piece = piece[max_length:]
        if piece:
            bounded.append(piece)

    # Greedily group sentences into chunks no longer than max_length.
    chunks = []
    current_chunk = ""
    for sentence in bounded:
        # If adding this sentence would exceed max_length and we already have
        # content, finish the current chunk and start a new one.
        if len(current_chunk) + len(sentence) > max_length and current_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += sentence

    # Add the last chunk if it has content.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
|
|
|
|
def parse_dialog_line(line):
    """Parse one dialog line of the form: Name: "Text".

    The speaker is everything before the first colon; the quoted text must
    span the rest of the line.  Surrounding whitespace is trimmed from both
    parts.

    Returns:
        A (speaker, text) tuple, or None when the line does not match.
    """
    m = re.match(r'^([^:]+):\s*"([^"]+)"$', line.strip())
    if m is None:
        return None
    return m.group(1).strip(), m.group(2).strip()
|
|
|
|
def main():
    """Command-line entry point: synthesize per-line dialog audio.

    Reads a YAML config mapping speaker names to reference audio sample
    paths, and a dialog file of lines in the form `Name: "Text"`.  For each
    parsable line whose speaker is in the config, generates one numbered WAV
    file (or several, when the text exceeds 300 characters and is chunked at
    sentence boundaries) via ChatterboxTTS voice cloning, then prints a
    summary of the files produced.
    """
    parser = argparse.ArgumentParser(description="Generate dialog audio from markdown file using Chatterbox TTS")
    parser.add_argument('--config', required=True, type=str, help='YAML config file mapping speaker names to audio samples')
    parser.add_argument('--dialog', required=True, type=str, help='Markdown dialog file')
    parser.add_argument('--output-base', required=True, type=str, help='Base name for output files (e.g., "output" for "001-output.wav")')
    parser.add_argument('--reinit-each-line', action='store_true', help='Re-initialize the model after each line to reduce memory usage')
    args = parser.parse_args()

    # Load the YAML config: expected to be a flat mapping of
    # speaker name -> reference audio sample path (see lookup below).
    with open(args.config, 'r') as f:
        speaker_samples = yaml.safe_load(f)

    # Load the dialog file; each line is parsed independently.
    with open(args.dialog, 'r') as f:
        dialog_lines = f.readlines()

    # Initialize model only once if not reinitializing per line.
    # NOTE(review): device="mps" assumes an Apple Silicon host — confirm
    # the intended deployment target.
    model = None
    if not args.reinit_each_line:
        print("Loading ChatterboxTTS model once for all lines...")
        model = ChatterboxTTS.from_pretrained(device="mps")

    # file_counter numbers output files globally across all lines and
    # chunks (001-, 002-, ...); summary collects one entry per file.
    file_counter = 1
    summary = []

    for line_num, line in enumerate(dialog_lines, 1):
        parsed = parse_dialog_line(line)
        if not parsed:
            # Lines not matching `Name: "Text"` are skipped, not fatal.
            print(f"Skipping line {line_num}: Not in the expected format")
            continue

        speaker, text = parsed

        # Check if the speaker is in the config; unknown speakers are skipped.
        if speaker not in speaker_samples:
            print(f"Warning: Speaker '{speaker}' not found in config, skipping line {line_num}")
            continue

        sample_path = speaker_samples[speaker]

        # Reinitialize model if needed.  The `model is None` arm also covers
        # the very first line when --reinit-each-line was given.
        if args.reinit_each_line or model is None:
            if args.reinit_each_line:
                print(f"Reinitializing model for line {line_num}...")
            model = ChatterboxTTS.from_pretrained(device="mps")

        # Check if the text needs to be split (> 300 chars): long lines are
        # synthesized chunk-by-chunk at sentence boundaries.
        if len(text) > 300:
            chunks = split_text_at_sentence_boundaries(text)
            # NOTE(review): chunk_files is populated but never read afterwards.
            chunk_files = []

            for chunk in chunks:
                output_file = f"{file_counter:03d}-{args.output_base}.wav"

                # Generate audio for this chunk, cloning the speaker's voice
                # from the configured reference sample.
                wav = model.generate(chunk, audio_prompt_path=sample_path)
                ta.save(output_file, wav, model.sr)

                chunk_files.append(output_file)
                summary.append(f"File {output_file}: {speaker} (chunk) - {chunk[:50]}...")
                file_counter += 1

            print(f"Generated {len(chunks)} files for line {line_num} (speaker: {speaker})")
        else:
            # Generate a single file for this line.
            output_file = f"{file_counter:03d}-{args.output_base}.wav"

            # Generate audio with the speaker's reference sample as the prompt.
            wav = model.generate(text, audio_prompt_path=sample_path)
            ta.save(output_file, wav, model.sr)

            summary.append(f"File {output_file}: {speaker} - {text[:50]}...")
            file_counter += 1
            print(f"Generated file for line {line_num} (speaker: {speaker})")

    # Print summary of every file written this run.
    print("\nSummary of generated files:")
    for entry in summary:
        print(entry)
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
|