78 lines
2.6 KiB
Python
Executable File
78 lines
2.6 KiB
Python
Executable File
import argparse
|
|
import gc
|
|
import torch
|
|
import torchaudio as ta
|
|
from chatterbox.tts import ChatterboxTTS
|
|
from contextlib import contextmanager
|
|
|
|
# Import helper to fix Python path
|
|
import import_helper
|
|
|
|
def safe_load_chatterbox_tts(device):
    """
    Load the ChatterboxTTS model while remapping checkpoint tensors to a usable device.

    ``torch.load`` is monkey-patched for the duration of
    ``ChatterboxTTS.from_pretrained`` so that every call which does not
    already specify ``map_location`` gets one injected: the MPS device when
    ``device`` is ``"mps"`` and MPS is available, otherwise the CPU. The
    original ``torch.load`` is always restored afterwards, even when loading
    raises.

    Args:
        device: Device name forwarded to ``ChatterboxTTS.from_pretrained``
            (for example "mps", "cuda" or "cpu").

    Returns:
        The object returned by ``ChatterboxTTS.from_pretrained(device=device)``.
    """
    @contextmanager
    def patch_torch_load(target_device):
        """Temporarily swap ``torch.load`` for a variant that defaults ``map_location``."""
        original_load = torch.load

        def patched_load(*args, **kwargs):
            # ``map_location`` is the second positional parameter of
            # torch.load. Only inject the keyword when the caller supplied it
            # neither positionally nor by name; otherwise the call would fail
            # with "got multiple values for argument 'map_location'".
            if len(args) < 2 and 'map_location' not in kwargs:
                if target_device == "mps" and torch.backends.mps.is_available():
                    kwargs['map_location'] = torch.device('mps')
                else:
                    kwargs['map_location'] = torch.device('cpu')
            return original_load(*args, **kwargs)

        torch.load = patched_load
        try:
            yield
        finally:
            # Undo the monkey-patch no matter how the with-body exits.
            torch.load = original_load

    with patch_torch_load(device):
        return ChatterboxTTS.from_pretrained(device=device)
|
|
|
|
def main():
    """
    Command-line entry point: synthesize speech for a text in a reference voice.

    Parses --sample, --output, --text and --device, loads the model through
    safe_load_chatterbox_tts, generates audio under torch.no_grad(), saves it
    via torchaudio and prints the output path. The model and audio references
    are dropped and the accelerator cache is emptied afterwards, whether or
    not generation succeeded.
    """
    cli = argparse.ArgumentParser(description="Chatterbox TTS audio generation")

    # The three mandatory string options share the same argparse settings.
    required_text_options = (
        ('--sample', 'Prompt/reference audio file (e.g. .wav, .mp3) for the voice'),
        ('--output', 'Output audio file path (should end with .wav)'),
        ('--text', 'Text to synthesize'),
    )
    for flag, description in required_text_options:
        cli.add_argument(flag, required=True, type=str, help=description)
    cli.add_argument(
        '--device',
        default="mps",
        choices=["mps", "cuda", "cpu"],
        help='Device to use for TTS (default: mps)',
    )
    opts = cli.parse_args()
    requested_device = opts.device

    # Bound up front so the cleanup below can always reference both names.
    tts = None
    audio = None

    try:
        tts = safe_load_chatterbox_tts(requested_device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            audio = tts.generate(opts.text, audio_prompt_path=opts.sample)

        ta.save(opts.output, audio, tts.sr)
        print(f"Generated audio saved to {opts.output}")
    finally:
        # Drop our references first so the collector can reclaim the tensors.
        del audio, tts
        gc.collect()

        # Release cached accelerator memory for the device that was requested.
        if requested_device == "cuda":
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        elif requested_device == "mps":
            if torch.backends.mps.is_available() and hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|