# chatterbox-ui/backend/app/services/tts_service.py

import torch
import torchaudio
from typing import Optional
from chatterbox.tts import ChatterboxTTS
from pathlib import Path
import gc # Garbage collector for memory management
import os
from contextlib import contextmanager
# Import configuration
try:
    from app.config import TTS_TEMP_OUTPUT_DIR, SPEAKER_SAMPLES_DIR
except ModuleNotFoundError:
    # When imported from scripts at project root
    from backend.app.config import TTS_TEMP_OUTPUT_DIR, SPEAKER_SAMPLES_DIR
# Use configuration for TTS output directory
TTS_OUTPUT_DIR = TTS_TEMP_OUTPUT_DIR


def safe_load_chatterbox_tts(device):
    """
    Safely load ChatterboxTTS model with device mapping to handle CUDA->MPS/CPU conversion.
    This patches torch.load temporarily to map CUDA tensors to the appropriate device.
    """
    @contextmanager
    def patch_torch_load(target_device):
        original_load = torch.load

        def patched_load(*args, **kwargs):
            # Add map_location to handle device mapping
            if 'map_location' not in kwargs:
                if target_device == "mps" and torch.backends.mps.is_available():
                    kwargs['map_location'] = torch.device('mps')
                else:
                    kwargs['map_location'] = torch.device('cpu')
            return original_load(*args, **kwargs)

        torch.load = patched_load
        try:
            yield
        finally:
            torch.load = original_load

    with patch_torch_load(device):
        return ChatterboxTTS.from_pretrained(device=device)
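
# A minimal usage sketch (illustration only, not part of the service): on an
# Apple Silicon Mac with the chatterbox package installed, this loads the
# model onto MPS (or maps the weights to CPU when MPS is unavailable):
#
#     model = safe_load_chatterbox_tts("mps")
#     wav = model.generate(text="Hello.", audio_prompt_path="speaker.wav")
#     torchaudio.save("out.wav", wav, model.sr)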


class TTSService:
    def __init__(self, device: str = "mps"):  # Default to MPS for Macs; can be "cpu" or "cuda"
        self.device = device
        self.model = None
        self._ensure_output_dir_exists()

    def _ensure_output_dir_exists(self):
        """Ensures the TTS output directory exists."""
        TTS_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    def load_model(self):
        """Loads the ChatterboxTTS model."""
        if self.model is None:
            print(f"Loading ChatterboxTTS model to device: {self.device}...")
            try:
                self.model = safe_load_chatterbox_tts(self.device)
                print("ChatterboxTTS model loaded successfully.")
            except Exception as e:
                print(f"Error loading ChatterboxTTS model: {e}")
                raise  # Re-raise so callers can handle the failure
        else:
            print("ChatterboxTTS model already loaded.")

    def unload_model(self):
        """Unloads the model and clears memory."""
        if self.model is not None:
            print("Unloading ChatterboxTTS model and clearing cache...")
            del self.model
            self.model = None
            if self.device == "cuda":
                torch.cuda.empty_cache()
            elif self.device == "mps":
                if hasattr(torch.mps, "empty_cache"):  # Check if empty_cache is available for MPS
                    torch.mps.empty_cache()
            gc.collect()  # Explicitly run garbage collection
            print("Model unloaded and memory cleared.")

    async def generate_speech(
        self,
        text: str,
        speaker_sample_path: str,  # Absolute path to the speaker's audio sample
        output_filename_base: str,  # e.g., "dialog_line_1_spk_X_chunk_0"
        speaker_id: Optional[str] = None,  # Optional, mainly for logging; the filename base is primary
        output_dir: Optional[Path] = None,  # Optional, defaults to TTS_OUTPUT_DIR from this module
        exaggeration: float = 0.5,  # Default from Gradio
        cfg_weight: float = 0.5,  # Default from Gradio
        temperature: float = 0.8,  # Default from Gradio
        unload_after: bool = False,  # Whether to unload the model after generation
    ) -> Path:
        """
        Generates speech from text using the loaded TTS model and a speaker sample.
        Saves the output to a .wav file.
        """
        if self.model is None:
            self.load_model()
            if self.model is None:  # Check again if loading failed
                raise RuntimeError("TTS model is not loaded. Cannot generate speech.")

        # Ensure speaker_sample_path is valid
        speaker_sample_p = Path(speaker_sample_path)
        if not speaker_sample_p.exists() or not speaker_sample_p.is_file():
            raise FileNotFoundError(f"Speaker sample audio file not found: {speaker_sample_path}")

        target_output_dir = output_dir if output_dir is not None else TTS_OUTPUT_DIR
        target_output_dir.mkdir(parents=True, exist_ok=True)
        # output_filename_base from DialogProcessorService is expected to be comprehensive
        # (e.g., it includes speaker_id and segment info)
        output_file_path = target_output_dir / f"{output_filename_base}.wav"

        print(f"Generating audio for text: \"{text[:50]}...\" with speaker sample: {speaker_sample_path}")
        wav = None
        try:
            with torch.no_grad():  # Important for inference
                wav = self.model.generate(
                    text=text,
                    audio_prompt_path=str(speaker_sample_p),  # Must be a string path
                    exaggeration=exaggeration,
                    cfg_weight=cfg_weight,
                    temperature=temperature,
                )
            torchaudio.save(str(output_file_path), wav, self.model.sr)
            print(f"Audio saved to: {output_file_path}")
            return output_file_path
        except Exception as e:
            print(f"Error during TTS generation or saving: {e}")
            raise
        finally:
            # Explicitly delete the wav tensor to free memory
            if wav is not None:
                del wav
            # Force garbage collection and cache cleanup
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            elif self.device == "mps":
                if hasattr(torch.mps, "empty_cache"):
                    torch.mps.empty_cache()
            # Unload the model if requested
            if unload_after:
                print("Unloading TTS model after generation...")
                self.unload_model()
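
# A minimal end-to-end sketch (illustration only; the paths and filename base
# below are hypothetical). generate_speech() is a coroutine and must be awaited:
#
#     svc = TTSService(device="mps")  # or "cpu" / "cuda"
#     out_path = await svc.generate_speech(
#         text="Hello there.",
#         speaker_sample_path="/abs/path/to/speaker.wav",
#         output_filename_base="dialog_line_1_spk_X_chunk_0",
#         unload_after=True,  # free the model after the last job in a batch
#     )
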
# Example usage (for testing, not part of the service itself)
if __name__ == "__main__":
    import asyncio

    async def main_test():
        tts_service = TTSService(device="mps")
        try:
            tts_service.load_model()

            dummy_speaker_root = SPEAKER_SAMPLES_DIR
            dummy_speaker_root.mkdir(parents=True, exist_ok=True)
            dummy_sample_file = dummy_speaker_root / "dummy_speaker_test.wav"

            # Always try to remove an existing dummy file to ensure a fresh one is created
            if dummy_sample_file.exists():
                try:
                    os.remove(dummy_sample_file)
                    print(f"Removed existing dummy sample: {dummy_sample_file}")
                except OSError as e:
                    print(f"Error removing existing dummy sample {dummy_sample_file}: {e}")
                    # Proceeding, but torchaudio.save might fail or overwrite

            print(f"Creating new dummy speaker sample: {dummy_sample_file}")
            # Create a minimal, silent WAV file for testing
            sample_rate = 22050
            duration = 1  # seconds
            num_channels = 1
            num_frames = sample_rate * duration
            audio_data = torch.zeros((num_channels, num_frames))
            try:
                torchaudio.save(str(dummy_sample_file), audio_data, sample_rate)
                print(f"Dummy sample created successfully: {dummy_sample_file}")
            except Exception as save_e:
                print(f"Could not create dummy sample: {save_e}")
                # If creation fails, the generation test below is skipped

            if dummy_sample_file.exists():
                output_path = await tts_service.generate_speech(
                    text="Hello, this is a test of the Text-to-Speech service.",
                    speaker_id="test_speaker",
                    speaker_sample_path=str(dummy_sample_file),
                    output_filename_base="test_generation",
                )
                print(f"Test generation output: {output_path}")
            else:
                print(f"Skipping generation test as dummy sample {dummy_sample_file} was not found.")
        except Exception:
            import traceback
            print("Error during TTS test run:")
            traceback.print_exc()
        finally:
            tts_service.unload_model()

    asyncio.run(main_test())