# chatterbox-ui/test.py
#
# 52 lines
# 1.9 KiB
# Python

import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
# Pick the compute backend: use Apple's MPS (M1/M2/M3/M4 Macs) when PyTorch
# reports it as available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
def safe_load_chatterbox_tts(device="mps"):
    """
    Load the pretrained ChatterboxTTS model with checkpoints remapped to ``device``.

    ``torch.load`` is temporarily replaced by a wrapper that supplies a
    ``map_location`` whenever the caller did not give one, so that
    checkpoints saved on CUDA can be loaded on MPS/CPU. The original
    ``torch.load`` is always restored, even when loading raises.

    Args:
        device: Target device string, forwarded to
            ``ChatterboxTTS.from_pretrained`` (e.g. ``"mps"``, ``"cpu"``).

    Returns:
        Whatever ``ChatterboxTTS.from_pretrained(device=device)`` returns.
    """
    # The fallback location depends only on ``device`` and backend
    # availability, so resolve it once instead of on every torch.load call.
    if device == "mps" and torch.backends.mps.is_available():
        default_location = torch.device("mps")
    elif device == "cpu" or not torch.cuda.is_available():
        default_location = torch.device("cpu")
    else:
        default_location = torch.device(device)

    # Keep a handle on the real loader so it can be called and restored.
    original_torch_load = torch.load

    def patched_torch_load(f, map_location=None, *args, **kwargs):
        # Extra positional arguments (such as ``pickle_module``) are forwarded
        # untouched so the wrapper accepts every call form torch.load accepts.
        if map_location is None:
            map_location = default_location
        return original_torch_load(f, map_location, *args, **kwargs)

    # Patch only for the duration of the model load.
    torch.load = patched_torch_load
    try:
        return ChatterboxTTS.from_pretrained(device=device)
    finally:
        # Undo the monkey patch regardless of success or failure.
        torch.load = original_torch_load
# Load the model on the detected device, remapping checkpoint tensors as needed.
model = safe_load_chatterbox_tts(device=device)

# Sentence(s) to synthesize.
text = (
    "Today is the day. "
    "I want to move like a titan at dawn, sweat like a god forging lightning. "
    "No more excuses. "
    "From now on, my mornings will be temples of discipline. "
    "I am going to work out like the gods… every damn day."
)

# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"

# Generate the waveform; the keyword settings are passed straight to generate().
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH, exaggeration=2.0, cfg_weight=0.5)

# Write the result to disk at the sample rate exposed by the model.
ta.save("test-2.wav", wav, model.sr)