chatterbox-ui/cbx-generate.py

78 lines
2.6 KiB
Python
Executable File

import argparse
import gc
import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from contextlib import contextmanager
# Import helper to fix Python path
import import_helper
def safe_load_chatterbox_tts(device):
    """
    Load the pretrained ChatterboxTTS model with checkpoint device remapping.

    While the model is being loaded, ``torch.load`` is temporarily replaced by
    a wrapper that supplies a ``map_location`` whenever the caller did not
    provide one: the MPS device when ``device`` is "mps" and MPS is available,
    otherwise the CPU. The original ``torch.load`` is always restored
    afterwards, even if loading raises.

    Args:
        device: Target device name, forwarded to
            ``ChatterboxTTS.from_pretrained`` (the CLI offers "mps", "cuda"
            and "cpu").

    Returns:
        The object returned by ``ChatterboxTTS.from_pretrained(device=device)``.
    """
    @contextmanager
    def patch_torch_load(target_device):
        """Swap ``torch.load`` for the remapping wrapper within the block."""
        original_load = torch.load

        def patched_load(*args, **kwargs):
            # ``map_location`` is the second positional parameter of
            # torch.load; injecting the keyword when it was already passed
            # positionally would raise TypeError (multiple values), so treat
            # both spellings as "explicitly provided".
            map_location_given = len(args) >= 2 or 'map_location' in kwargs
            if not map_location_given:
                if target_device == "mps" and torch.backends.mps.is_available():
                    kwargs['map_location'] = torch.device('mps')
                else:
                    # Any other target (including "cuda") is deserialized on
                    # the CPU here.
                    kwargs['map_location'] = torch.device('cpu')
            return original_load(*args, **kwargs)

        torch.load = patched_load
        try:
            yield
        finally:
            # Undo the monkey-patch no matter how the block exits.
            torch.load = original_load

    with patch_torch_load(device):
        return ChatterboxTTS.from_pretrained(device=device)
def main():
    """
    Command-line entry point for Chatterbox TTS generation.

    Parses ``--sample``, ``--output``, ``--text`` and ``--device``, loads the
    model through ``safe_load_chatterbox_tts``, synthesizes the text using the
    sample as the audio prompt, writes the result to the output path, and
    finally drops the model/audio references and empties the device cache.
    """
    cli = argparse.ArgumentParser(description="Chatterbox TTS audio generation")

    # The three mandatory string options share the same shape, so register
    # them from a table (order matches the help output of the parser).
    required_options = (
        ('--sample', 'Prompt/reference audio file (e.g. .wav, .mp3) for the voice'),
        ('--output', 'Output audio file path (should end with .wav)'),
        ('--text', 'Text to synthesize'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, type=str, help=help_text)
    cli.add_argument(
        '--device',
        default="mps",
        choices=["mps", "cuda", "cpu"],
        help='Device to use for TTS (default: mps)',
    )
    args = cli.parse_args()

    # Bound up front so the cleanup below can run even if loading fails.
    tts = None
    audio = None
    try:
        # Model loading goes through the device-remapping loader.
        tts = safe_load_chatterbox_tts(args.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            audio = tts.generate(args.text, audio_prompt_path=args.sample)

        # Write the waveform at the model's sample rate.
        ta.save(args.output, audio, tts.sr)
        print(f"Generated audio saved to {args.output}")
    finally:
        # Drop our references first so the collector can reclaim them.
        del audio
        del tts
        gc.collect()

        # Then release cached allocator memory on the selected backend.
        if args.device == "cuda":
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        elif args.device == "mps":
            if torch.backends.mps.is_available():
                if hasattr(torch.mps, "empty_cache"):
                    torch.mps.empty_cache()
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()