"""Speaker management service.

Keeps a YAML registry of speakers (``config.SPEAKERS_YAML_FILE``) together
with one WAV reference sample per speaker under
``config.SPEAKER_SAMPLES_DIR``.
"""

import io  # Added for BytesIO
import os
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

import torchaudio  # Added for audio processing
import yaml
from fastapi import HTTPException, UploadFile

try:
    from app.models.speaker_models import Speaker, SpeakerCreate
    from app import config
except ModuleNotFoundError:
    # When imported from scripts at project root
    from backend.app.models.speaker_models import Speaker, SpeakerCreate
    from backend.app import config


# Exception types treated as "the uploaded audio could not be decoded".
# The optional torchaudio class is resolved with getattr because naming a
# missing attribute directly in an ``except`` clause would itself raise
# AttributeError while the original error is being handled, bypassing every
# handler. RuntimeError is included because that is what torchaudio backends
# typically raise for unsupported or corrupted input.
_TORCHAUDIO_ERROR = getattr(torchaudio, "TorchaudioException", None)
if isinstance(_TORCHAUDIO_ERROR, type) and issubclass(_TORCHAUDIO_ERROR, BaseException):
    _AUDIO_DECODE_ERRORS = (RuntimeError, _TORCHAUDIO_ERROR)
else:
    _AUDIO_DECODE_ERRORS = (RuntimeError,)


class SpeakerManagementService:
    """CRUD-style service over the speaker registry and its audio samples.

    ``self.speakers_data`` is a dict keyed by speaker id whose values are the
    attribute dicts (``name``, ``sample_path``, ``reference_text``) passed to
    the ``Speaker`` model.
    """

    def __init__(self):
        self._ensure_data_files_exist()
        self.speakers_data = self._load_speakers_data()

    def _ensure_data_files_exist(self):
        """Ensures the speaker data directory and YAML file exist."""
        config.SPEAKER_DATA_BASE_DIR.mkdir(parents=True, exist_ok=True)
        config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
        if not config.SPEAKERS_YAML_FILE.exists():
            with open(config.SPEAKERS_YAML_FILE, 'w') as f:
                yaml.dump({}, f)  # Initialize with an empty dict

    def _load_speakers_data(self) -> Dict[str, Any]:
        """Loads speaker data from the YAML file.

        Returns:
            The parsed mapping, or an empty dict when the file is missing,
            corrupted, or does not contain a mapping at its top level.
        """
        try:
            with open(config.SPEAKERS_YAML_FILE, 'r') as f:
                data = yaml.safe_load(f)
            return data if isinstance(data, dict) else {}  # Ensure it's a dict
        except FileNotFoundError:
            return {}
        except yaml.YAMLError:
            # Corrupted YAML file: report it and fall back to an empty dict
            print(f"Error: Corrupted speakers YAML file at {config.SPEAKERS_YAML_FILE}")
            return {}

    def _save_speakers_data(self):
        """Saves the current speaker data to the YAML file."""
        with open(config.SPEAKERS_YAML_FILE, 'w') as f:
            yaml.dump(self.speakers_data, f, sort_keys=False)

    @staticmethod
    def _remove_sample_file(sample_file: Path) -> None:
        """Best-effort removal of a sample file; failures are reported, not raised."""
        try:
            if sample_file.is_file():  # Check if it's a file before removing
                os.remove(sample_file)
        except OSError as e:
            # Log error if file deletion fails but proceed
            print(f"Error deleting sample file {sample_file}: {e}")

    def get_speakers(self) -> List[Speaker]:
        """Returns a list of all speakers."""
        # self.speakers_data is a dict: {speaker_id: {name: ..., sample_path: ...}}
        return [
            Speaker(id=spk_id, **spk_attrs)
            for spk_id, spk_attrs in self.speakers_data.items()
        ]

    def get_speaker_by_id(self, speaker_id: str) -> Optional[Speaker]:
        """Retrieves a speaker by their ID, or None when the ID is unknown."""
        if speaker_id in self.speakers_data:
            speaker_attributes = self.speakers_data[speaker_id]
            return Speaker(id=speaker_id, **speaker_attributes)
        return None

    async def add_speaker(self, name: str, audio_file: UploadFile, reference_text: str) -> Speaker:
        """Add a new speaker for Higgs TTS.

        The uploaded audio is decoded with torchaudio and re-encoded as
        ``<speaker_id>.wav`` inside ``config.SPEAKER_SAMPLES_DIR``; the
        speaker is then recorded in the YAML registry.

        Args:
            name: Display name of the speaker.
            audio_file: Uploaded reference audio; always closed before returning.
            reference_text: Transcript of the reference audio (required,
                at most 500 characters after stripping).

        Returns:
            The newly created ``Speaker``.

        Raises:
            HTTPException: 400 when ``reference_text`` is missing/too long or
                the audio cannot be decoded; 500 when the sample cannot be
                read or written.
        """
        cleaned_reference_text = reference_text.strip() if reference_text else ""

        # Validate required reference text
        if not cleaned_reference_text:
            raise HTTPException(
                status_code=400,
                detail="reference_text is required for Higgs TTS"
            )

        # Validate reference text length
        if len(cleaned_reference_text) > 500:
            raise HTTPException(
                status_code=400,
                detail="reference_text should be under 500 characters for optimal performance"
            )

        speaker_id = str(uuid.uuid4())

        # Define standardized sample filename and path (always WAV)
        sample_filename = f"{speaker_id}.wav"
        sample_path = config.SPEAKER_SAMPLES_DIR / sample_filename

        try:
            content = await audio_file.read()

            try:
                # BytesIO lets torchaudio decode the in-memory upload.
                # waveform is a tensor, sample_rate is an int.
                waveform, sample_rate = torchaudio.load(io.BytesIO(content))
            except _AUDIO_DECODE_ERRORS as e:
                # Undecodable upload (unsupported format, corrupted file) is a client error
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing audio file: {e}. Ensure it's a valid audio format (e.g., WAV, MP3)."
                ) from e

            # Ensure the SPEAKER_SAMPLES_DIR exists (though _ensure_data_files_exist should handle it)
            config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
            torchaudio.save(str(sample_path), waveform, sample_rate, format="wav")
        except HTTPException:
            # Already a well-formed API error; do not rewrap it as a 500
            raise
        except Exception as e:
            # General error handling for other issues (e.g., file system errors).
            # Drop any partially written sample so no orphan file is left behind.
            self._remove_sample_file(sample_path)
            raise HTTPException(status_code=500, detail=f"Could not save audio file: {e}") from e
        finally:
            await audio_file.close()

        new_speaker_data = {
            "name": name,
            "sample_path": str(sample_path.relative_to(config.SPEAKER_DATA_BASE_DIR)),
            "reference_text": cleaned_reference_text
        }

        # Store in speakers_data dict
        self.speakers_data[speaker_id] = new_speaker_data
        self._save_speakers_data()

        # Construct Speaker model for return, including the ID
        return Speaker(id=speaker_id, **new_speaker_data)

    def delete_speaker(self, speaker_id: str) -> bool:
        """Deletes a speaker and their audio sample.

        Returns:
            True when the speaker existed and was removed from the registry,
            False when the ID is unknown. Failure to delete the sample file
            is reported but does not change the result.
        """
        # Membership (not truthiness) decides existence, so an empty or null
        # entry is still removed and persisted instead of being silently dropped.
        if speaker_id not in self.speakers_data:
            return False

        speaker_to_delete = self.speakers_data.pop(speaker_id)
        self._save_speakers_data()

        sample_path_str = (
            speaker_to_delete.get("sample_path")
            if isinstance(speaker_to_delete, dict)
            else None
        )
        if sample_path_str:
            # sample_path_str is relative to SPEAKER_DATA_BASE_DIR
            self._remove_sample_file(config.SPEAKER_DATA_BASE_DIR / sample_path_str)
        return True

    def validate_all_speakers(self) -> dict:
        """Validate all speakers against current requirements.

        Returns:
            A dict with ``total_speakers``, ``valid_speakers``,
            ``invalid_speakers`` and ``validation_errors`` (one entry per
            invalid speaker with its id, name and error message).
        """
        validation_results = {
            "total_speakers": len(self.speakers_data),
            "valid_speakers": 0,
            "invalid_speakers": 0,
            "validation_errors": []
        }

        for speaker_id, speaker_data in self.speakers_data.items():
            try:
                # Instantiate the Speaker model purely to trigger its validation
                Speaker(id=speaker_id, **speaker_data)
            except Exception as e:
                # Any failure marks the entry invalid. The entry itself may not
                # even be a dict (e.g. a null YAML value), so guard the lookup.
                speaker_name = (
                    speaker_data.get("name", "Unknown")
                    if isinstance(speaker_data, dict)
                    else "Unknown"
                )
                validation_results["invalid_speakers"] += 1
                validation_results["validation_errors"].append({
                    "speaker_id": speaker_id,
                    "speaker_name": speaker_name,
                    "error": str(e)
                })
            else:
                validation_results["valid_speakers"] += 1

        return validation_results


# Example usage (for testing, not part of the service itself)
if __name__ == "__main__":
    service = SpeakerManagementService()
    print("Initial speakers:", service.get_speakers())

    # This part would require a mock UploadFile (with real, decodable audio
    # bytes) to run directly
    # print("\nAdding a new speaker (manual test setup needed for UploadFile)")
    # class MockUploadFile:
    #     def __init__(self, filename, content):
    #         self.filename = filename
    #         self._content = content
    #     async def read(self): return self._content
    #     async def close(self): pass

    # import asyncio
    # async def test_add():
    #     mock_file = MockUploadFile("test.wav", b"dummy audio content")
    #     new_speaker = await service.add_speaker(
    #         name="Test Speaker",
    #         audio_file=mock_file,
    #         reference_text="Transcript of the sample audio.",
    #     )
    #     print("\nAdded speaker:", new_speaker)
    #     print("Speakers after add:", service.get_speakers())
    #     return new_speaker.id

    # speaker_id_to_delete = asyncio.run(test_add())

    # if speaker_id_to_delete:
    #     print(f"\nDeleting speaker {speaker_id_to_delete}")
    #     service.delete_speaker(speaker_id_to_delete)
    #     print("Speakers after delete:", service.get_speakers())