import torch import torchaudio from typing import Optional from chatterbox.tts import ChatterboxTTS from pathlib import Path import gc # Garbage collector for memory management import os from contextlib import contextmanager # Import configuration try: from app.config import TTS_TEMP_OUTPUT_DIR, SPEAKER_SAMPLES_DIR except ModuleNotFoundError: # When imported from scripts at project root from backend.app.config import TTS_TEMP_OUTPUT_DIR, SPEAKER_SAMPLES_DIR # Use configuration for TTS output directory TTS_OUTPUT_DIR = TTS_TEMP_OUTPUT_DIR def safe_load_chatterbox_tts(device): """ Safely load ChatterboxTTS model with device mapping to handle CUDA->MPS/CPU conversion. This patches torch.load temporarily to map CUDA tensors to the appropriate device. """ @contextmanager def patch_torch_load(target_device): original_load = torch.load def patched_load(*args, **kwargs): # Add map_location to handle device mapping if 'map_location' not in kwargs: if target_device == "mps" and torch.backends.mps.is_available(): kwargs['map_location'] = torch.device('mps') else: kwargs['map_location'] = torch.device('cpu') return original_load(*args, **kwargs) torch.load = patched_load try: yield finally: torch.load = original_load with patch_torch_load(device): return ChatterboxTTS.from_pretrained(device=device) class TTSService: def __init__(self, device: str = "mps"): # Default to MPS for Macs, can be "cpu" or "cuda" self.device = device self.model = None self._ensure_output_dir_exists() def _ensure_output_dir_exists(self): """Ensures the TTS output directory exists.""" TTS_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) def load_model(self): """Loads the ChatterboxTTS model.""" if self.model is None: print(f"Loading ChatterboxTTS model to device: {self.device}...") try: self.model = safe_load_chatterbox_tts(self.device) print("ChatterboxTTS model loaded successfully.") except Exception as e: print(f"Error loading ChatterboxTTS model: {e}") # Potentially raise an exception or handle appropriately raise else: print("ChatterboxTTS model already loaded.") def unload_model(self): """Unloads the model and clears memory.""" if self.model is not None: print("Unloading ChatterboxTTS model and clearing cache...") del self.model self.model = None if self.device == "cuda": torch.cuda.empty_cache() elif self.device == "mps": if hasattr(torch.mps, "empty_cache"): # Check if empty_cache is available for MPS torch.mps.empty_cache() gc.collect() # Explicitly run garbage collection print("Model unloaded and memory cleared.") async def generate_speech( self, text: str, speaker_sample_path: str, # Absolute path to the speaker's audio sample output_filename_base: str, # e.g., "dialog_line_1_spk_X_chunk_0" speaker_id: Optional[str] = None, # Optional, mainly for logging if needed, filename base is primary output_dir: Optional[Path] = None, # Optional, defaults to TTS_OUTPUT_DIR from this module exaggeration: float = 0.5, # Default from Gradio cfg_weight: float = 0.5, # Default from Gradio temperature: float = 0.8, # Default from Gradio unload_after: bool = False, # Whether to unload the model after generation ) -> Path: """ Generates speech from text using the loaded TTS model and a speaker sample. Saves the output to a .wav file. """ if self.model is None: self.load_model() if self.model is None: # Check again if loading failed raise RuntimeError("TTS model is not loaded. Cannot generate speech.") # Ensure speaker_sample_path is valid speaker_sample_p = Path(speaker_sample_path) if not speaker_sample_p.exists() or not speaker_sample_p.is_file(): raise FileNotFoundError(f"Speaker sample audio file not found: {speaker_sample_path}") target_output_dir = output_dir if output_dir is not None else TTS_OUTPUT_DIR target_output_dir.mkdir(parents=True, exist_ok=True) # output_filename_base from DialogProcessorService is expected to be comprehensive (e.g., includes speaker_id, segment info) output_file_path = target_output_dir / f"{output_filename_base}.wav" print(f"Generating audio for text: \"{text[:50]}...\" with speaker sample: {speaker_sample_path}") wav = None try: with torch.no_grad(): # Important for inference wav = self.model.generate( text=text, audio_prompt_path=str(speaker_sample_p), # Must be a string path exaggeration=exaggeration, cfg_weight=cfg_weight, temperature=temperature, ) torchaudio.save(str(output_file_path), wav, self.model.sr) print(f"Audio saved to: {output_file_path}") return output_file_path except Exception as e: print(f"Error during TTS generation or saving: {e}") raise finally: # Explicitly delete the wav tensor to free memory if wav is not None: del wav # Force garbage collection and cache cleanup gc.collect() if self.device == "cuda": torch.cuda.empty_cache() elif self.device == "mps": if hasattr(torch.mps, "empty_cache"): torch.mps.empty_cache() # Unload the model if requested if unload_after: print("Unloading TTS model after generation...") self.unload_model() # Example usage (for testing, not part of the service itself) if __name__ == "__main__": async def main_test(): tts_service = TTSService(device="mps") try: tts_service.load_model() dummy_speaker_root = SPEAKER_SAMPLES_DIR dummy_speaker_root.mkdir(parents=True, exist_ok=True) dummy_sample_file = dummy_speaker_root / "dummy_speaker_test.wav" import os # Added for os.remove # Always try to remove an existing dummy file to ensure a fresh one is created if dummy_sample_file.exists(): try: os.remove(dummy_sample_file) print(f"Removed existing dummy sample: {dummy_sample_file}") except OSError as e: print(f"Error removing existing dummy sample {dummy_sample_file}: {e}") # Proceeding, but torchaudio.save might fail or overwrite print(f"Creating new dummy speaker sample: {dummy_sample_file}") # Create a minimal, silent WAV file for testing sample_rate = 22050 duration = 1 # seconds num_channels = 1 num_frames = sample_rate * duration audio_data = torch.zeros((num_channels, num_frames)) try: torchaudio.save(str(dummy_sample_file), audio_data, sample_rate) print(f"Dummy sample created successfully: {dummy_sample_file}") except Exception as save_e: print(f"Could not create dummy sample: {save_e}") # If creation fails, the subsequent generation test will likely also fail or be skipped. if dummy_sample_file.exists(): output_path = await tts_service.generate_speech( text="Hello, this is a test of the Text-to-Speech service.", speaker_id="test_speaker", speaker_sample_path=str(dummy_sample_file), output_filename_base="test_generation" ) print(f"Test generation output: {output_path}") else: print(f"Skipping generation test as dummy sample {dummy_sample_file} not found.") except Exception as e: import traceback print(f"Error during TTS generation or saving:") traceback.print_exc() finally: tts_service.unload_model() import asyncio asyncio.run(main_test())