import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

# Detect device (Mac with M1/M2/M3/M4)
device = "mps" if torch.backends.mps.is_available() else "cpu"


def safe_load_chatterbox_tts(device="mps"):
    """
    Safely load the ChatterboxTTS model with proper device mapping.
    Handles cases where the model was saved on CUDA but needs to be loaded on MPS/CPU.
    """
    # Store the original torch.load function
    original_torch_load = torch.load

    def patched_torch_load(f, map_location=None, **kwargs):
        # If no map_location is specified and we're loading on a non-CUDA device,
        # map CUDA tensors to the target device
        if map_location is None:
            if device == "mps" and torch.backends.mps.is_available():
                map_location = torch.device("mps")
            elif device == "cpu" or not torch.cuda.is_available():
                map_location = torch.device("cpu")
            else:
                map_location = torch.device(device)
        return original_torch_load(f, map_location=map_location, **kwargs)

    # Temporarily patch torch.load
    torch.load = patched_torch_load

    try:
        # Load the model with the patched torch.load
        model = ChatterboxTTS.from_pretrained(device=device)
        return model
    finally:
        # Restore the original torch.load
        torch.load = original_torch_load


model = safe_load_chatterbox_tts(device=device)
text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day."

# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(
    text,
    audio_prompt_path=AUDIO_PROMPT_PATH,
    exaggeration=2.0,
    cfg_weight=0.5
)
ta.save("test-2.wav", wav, model.sr)
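
# Optional sketch (not part of the original script): the upstream chatterbox examples
# also call generate() without an audio prompt, which should fall back to the model's
# built-in default voice. Uncomment to try it; wav_default and the output filename
# are arbitrary names chosen here for illustration.
# wav_default = model.generate(text)
# ta.save("test-default.wav", wav_default, model.sr)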