import argparse
import gc
import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from contextlib import contextmanager

# Import helper to fix Python path (imported only for its side effect).
# NOTE(review): if this helper is what makes `chatterbox` importable, it would
# have to run before the chatterbox import above -- confirm the intended order.
import import_helper  # noqa: F401


def _default_map_location(target_device):
    """Return the ``map_location`` injected into un-mapped ``torch.load`` calls.

    Only an available MPS backend maps to ``mps``; every other target
    (including ``"cuda"``) maps to CPU.
    """
    # NOTE(review): a "cuda" target also falls through to CPU here; presumably
    # ChatterboxTTS.from_pretrained moves the weights to the requested device
    # afterwards -- confirm against the chatterbox sources.
    if target_device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


@contextmanager
def _patched_torch_load(target_device):
    """Temporarily replace ``torch.load`` with a variant that supplies a default
    ``map_location``; the original function is restored on exit, even on error.

    The patch is process-wide (it rebinds the ``torch.load`` attribute) for the
    duration of the ``with`` body.
    """
    original_load = torch.load

    def patched_load(*args, **kwargs):
        # torch.load accepts map_location as its second positional parameter,
        # so it may arrive positionally as well as by keyword. Injecting the
        # keyword on top of a positional value would raise
        # "TypeError: got multiple values for argument 'map_location'".
        caller_chose_location = len(args) >= 2 or "map_location" in kwargs
        if not caller_chose_location:
            kwargs["map_location"] = _default_map_location(target_device)
        return original_load(*args, **kwargs)

    torch.load = patched_load
    try:
        yield
    finally:
        torch.load = original_load


def safe_load_chatterbox_tts(device):
    """
    Safely load ChatterboxTTS model with device mapping to handle
    CUDA->MPS/CPU conversion.

    ``torch.load`` is patched for the duration of the call so that any load
    issued without an explicit ``map_location`` is mapped to MPS (when that is
    the target and the backend is available) or to CPU otherwise.

    Args:
        device: Target device name, passed through unchanged to
            ``ChatterboxTTS.from_pretrained`` ("mps", "cuda" or "cpu" when
            called from ``main``).

    Returns:
        Whatever ``ChatterboxTTS.from_pretrained(device=device)`` returns.
    """
    with _patched_torch_load(device):
        return ChatterboxTTS.from_pretrained(device=device)


def _release_device_memory(device):
    """Run the garbage collector, then empty the cache of the given backend."""
    gc.collect()
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "mps" and torch.backends.mps.is_available():
        # Guarded with hasattr because the attribute may be missing
        # (as the original code also assumed).
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()


def main():
    """Command-line entry point: synthesize ``--text`` in the voice of
    ``--sample`` and write the result to ``--output``.

    The model and the generated waveform are dropped and backend caches are
    emptied in a ``finally`` block, so cleanup also runs when loading,
    generation or saving raises.
    """
    parser = argparse.ArgumentParser(description="Chatterbox TTS audio generation")
    parser.add_argument(
        '--sample', required=True, type=str,
        help='Prompt/reference audio file (e.g. .wav, .mp3) for the voice',
    )
    parser.add_argument(
        '--output', required=True, type=str,
        help='Output audio file path (should end with .wav)',
    )
    parser.add_argument(
        '--text', required=True, type=str,
        help='Text to synthesize',
    )
    parser.add_argument(
        '--device', default="mps", choices=["mps", "cuda", "cpu"],
        help='Device to use for TTS (default: mps)',
    )
    args = parser.parse_args()

    # Pre-bound so the cleanup below is valid even if loading fails.
    model = None
    wav = None
    try:
        # Load model with safe device mapping
        model = safe_load_chatterbox_tts(args.device)

        # Generate the audio; no gradients are needed for inference.
        with torch.no_grad():
            wav = model.generate(args.text, audio_prompt_path=args.sample)

        # Save to output .wav
        ta.save(args.output, wav, model.sr)
        print(f"Generated audio saved to {args.output}")
    finally:
        # Drop our references first so the collector and the backend cache
        # flush can actually reclaim the memory.
        wav = None
        model = None
        _release_device_memory(args.device)


if __name__ == '__main__':
    main()