# chatterbox-ui/backend/app/services/speaker_service.py
#
# 193 lines
# 8.3 KiB
# Python

import io  # Added for BytesIO
import logging
import os
import uuid
from pathlib import Path
from typing import List, Dict, Optional, Any

import torchaudio  # Added for audio processing
import yaml
from fastapi import UploadFile, HTTPException

try:
    from app.models.speaker_models import Speaker, SpeakerCreate
    from app import config
except ModuleNotFoundError:
    # When imported from scripts at project root
    from backend.app.models.speaker_models import Speaker, SpeakerCreate
    from backend.app import config
_logger = logging.getLogger(__name__)


class SpeakerManagementService:
    """Manage speaker records and their reference audio samples.

    Speaker metadata is held in memory as ``{speaker_id: attributes}`` and
    persisted to ``config.SPEAKERS_YAML_FILE``. Each speaker's audio sample is
    stored as a WAV file under ``config.SPEAKER_SAMPLES_DIR`` and referenced
    by a path relative to ``config.SPEAKER_DATA_BASE_DIR``.
    """

    # Maximum accepted length (after stripping) of a speaker's reference text.
    MAX_REFERENCE_TEXT_LENGTH = 500

    def __init__(self):
        """Create the on-disk layout if needed and load all speaker records."""
        self._ensure_data_files_exist()
        self.speakers_data = self._load_speakers_data()

    @staticmethod
    def _remove_file_quietly(path: Path) -> None:
        """Delete ``path`` if it is a regular file; log instead of raising."""
        try:
            if path.is_file():
                os.remove(path)
        except OSError as e:
            _logger.warning("Error deleting sample file %s: %s", path, e)

    def _ensure_data_files_exist(self):
        """Ensure the speaker data directories and the YAML file exist."""
        config.SPEAKER_DATA_BASE_DIR.mkdir(parents=True, exist_ok=True)
        config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
        if not config.SPEAKERS_YAML_FILE.exists():
            with open(config.SPEAKERS_YAML_FILE, "w", encoding="utf-8") as f:
                # Start from an empty mapping; the loader expects a dict.
                yaml.safe_dump({}, f)

    def _load_speakers_data(self) -> Dict[str, Any]:
        """Load speaker data from the YAML file.

        Returns:
            The mapping stored in the file, or an empty dict when the file is
            missing, unreadable as YAML, or does not contain a mapping.

        Note:
            When an empty dict is returned for an unreadable file, the next
            call to ``_save_speakers_data`` overwrites that file.
        """
        try:
            with open(config.SPEAKERS_YAML_FILE, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except FileNotFoundError:
            return {}
        except (yaml.YAMLError, UnicodeDecodeError):
            _logger.exception(
                "Error: Corrupted speakers YAML file at %s",
                config.SPEAKERS_YAML_FILE,
            )
            return {}

        if isinstance(data, dict):
            return data
        if data is not None:
            _logger.warning(
                "Speakers YAML file at %s does not contain a mapping; ignoring it",
                config.SPEAKERS_YAML_FILE,
            )
        return {}

    def _save_speakers_data(self):
        """Write the in-memory speaker data to the YAML file."""
        # Serialise before opening the file: opening with "w" truncates it, so
        # a serialisation error must surface while the old contents still exist.
        # safe_dump keeps the output readable by the safe_load used on startup.
        serialized = yaml.safe_dump(self.speakers_data, sort_keys=False)
        with open(config.SPEAKERS_YAML_FILE, "w", encoding="utf-8") as f:
            f.write(serialized)

    def get_speakers(self) -> List["Speaker"]:
        """Return every stored speaker as a ``Speaker`` model."""
        return [
            Speaker(id=spk_id, **spk_attrs)
            for spk_id, spk_attrs in self.speakers_data.items()
        ]

    def get_speaker_by_id(self, speaker_id: str) -> Optional["Speaker"]:
        """Return the speaker stored under ``speaker_id``, or None if absent."""
        try:
            speaker_attributes = self.speakers_data[speaker_id]
        except KeyError:
            return None
        return Speaker(id=speaker_id, **speaker_attributes)

    async def add_speaker(self, name: str, audio_file: "UploadFile",
                          reference_text: str) -> "Speaker":
        """Add a new speaker for Higgs TTS.

        The uploaded audio is decoded with torchaudio and re-encoded as
        ``<speaker_id>.wav`` in the samples directory; the record is then
        stored in memory and persisted to the YAML file. Decoding and encoding
        run synchronously inside this coroutine.

        Args:
            name: Display name of the speaker.
            audio_file: Uploaded audio; it is closed before returning.
            reference_text: Text associated with the sample. Required, and at
                most ``MAX_REFERENCE_TEXT_LENGTH`` characters after stripping.

        Returns:
            The newly created ``Speaker``.

        Raises:
            HTTPException: 400 when ``reference_text`` is missing or too long,
                or when the audio cannot be decoded; 500 when the upload cannot
                be read or the WAV file cannot be written.
        """
        cleaned_reference_text = (reference_text or "").strip()
        if not cleaned_reference_text:
            raise HTTPException(
                status_code=400,
                detail="reference_text is required for Higgs TTS"
            )
        if len(cleaned_reference_text) > self.MAX_REFERENCE_TEXT_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"reference_text should be under {self.MAX_REFERENCE_TEXT_LENGTH} "
                    "characters for optimal performance"
                )
            )

        speaker_id = str(uuid.uuid4())
        # Samples are always stored as WAV, named after the speaker id.
        sample_path = config.SPEAKER_SAMPLES_DIR / f"{speaker_id}.wav"

        try:
            content = await audio_file.read()
            try:
                # BytesIO lets torchaudio decode the upload straight from memory.
                waveform, sample_rate = torchaudio.load(io.BytesIO(content))
            except Exception as e:
                # Any decoder failure is treated as bad client input.
                # NOTE(review): the exception type raised by torchaudio.load
                # depends on the backend (typically RuntimeError) -- confirm
                # before narrowing this handler.
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing audio file: {e}. Ensure it's a valid audio format (e.g., WAV, MP3)."
                ) from e
            config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
            torchaudio.save(str(sample_path), waveform, sample_rate, format="wav")
        except HTTPException:
            raise
        except Exception as e:
            # Reading the upload or writing the WAV failed; do not leave a
            # partially written sample behind.
            self._remove_file_quietly(sample_path)
            raise HTTPException(
                status_code=500,
                detail=f"Could not save audio file: {e}"
            ) from e
        finally:
            await audio_file.close()

        new_speaker_data = {
            "name": name,
            "sample_path": str(sample_path.relative_to(config.SPEAKER_DATA_BASE_DIR)),
            "reference_text": cleaned_reference_text,
        }

        self.speakers_data[speaker_id] = new_speaker_data
        try:
            self._save_speakers_data()
        except Exception:
            # Keep memory and disk consistent: drop the unsaved record and its
            # sample, then let the original error propagate.
            self.speakers_data.pop(speaker_id, None)
            self._remove_file_quietly(sample_path)
            raise

        return Speaker(id=speaker_id, **new_speaker_data)

    def delete_speaker(self, speaker_id: str) -> bool:
        """Delete a speaker record and, best-effort, its audio sample.

        Returns:
            True if a record with ``speaker_id`` existed and was removed,
            False otherwise.
        """
        # Test membership rather than the popped value's truthiness, so an
        # empty or null record is still removed and persisted.
        if speaker_id not in self.speakers_data:
            return False

        removed = self.speakers_data.pop(speaker_id)
        self._save_speakers_data()

        sample_path_str = removed.get("sample_path") if isinstance(removed, dict) else None
        if sample_path_str:
            # The stored path is relative to SPEAKER_DATA_BASE_DIR.
            self._remove_file_quietly(config.SPEAKER_DATA_BASE_DIR / sample_path_str)
        return True

    def validate_all_speakers(self) -> dict:
        """Validate every stored record by building a ``Speaker`` from it.

        Returns:
            A dict with ``total_speakers``, ``valid_speakers``,
            ``invalid_speakers`` and ``validation_errors`` (one entry per
            failing record with ``speaker_id``, ``speaker_name``, ``error``).
        """
        validation_results = {
            "total_speakers": len(self.speakers_data),
            "valid_speakers": 0,
            "invalid_speakers": 0,
            "validation_errors": []
        }

        for speaker_id, speaker_data in self.speakers_data.items():
            if not isinstance(speaker_data, dict):
                # A non-mapping record has no name to report and cannot be
                # expanded into keyword arguments.
                speaker_name = "Unknown"
                error = f"Speaker record must be a mapping, got {type(speaker_data).__name__}"
            else:
                try:
                    Speaker(id=speaker_id, **speaker_data)
                except Exception as e:
                    speaker_name = speaker_data.get("name", "Unknown")
                    error = str(e)
                else:
                    validation_results["valid_speakers"] += 1
                    continue

            validation_results["invalid_speakers"] += 1
            validation_results["validation_errors"].append({
                "speaker_id": speaker_id,
                "speaker_name": speaker_name,
                "error": error
            })

        return validation_results
# Example usage (for testing, not part of the service itself)
if __name__ == "__main__":
    service = SpeakerManagementService()
    print("Initial speakers:", service.get_speakers())

    # This part would require a mock UploadFile to run directly.
    # add_speaker awaits read() and close() on the upload, requires a
    # non-empty reference_text, and decodes the bytes with torchaudio, so the
    # content below must be replaced with real audio data before enabling it.
    #
    # print("\nAdding a new speaker (manual test setup needed for UploadFile)")
    # class MockUploadFile:
    #     def __init__(self, filename, content):
    #         self.filename = filename
    #         self._content = content
    #     async def read(self): return self._content
    #     async def close(self): pass
    #
    # import asyncio
    # async def test_add():
    #     mock_file = MockUploadFile("test.wav", b"<bytes of a real WAV file>")
    #     new_speaker = await service.add_speaker(
    #         name="Test Speaker",
    #         audio_file=mock_file,
    #         reference_text="Transcript of the sample audio.",
    #     )
    #     print("\nAdded speaker:", new_speaker)
    #     print("Speakers after add:", service.get_speakers())
    #     return new_speaker.id
    #
    # speaker_id_to_delete = asyncio.run(test_add())
    # if speaker_id_to_delete:
    #     print(f"\nDeleting speaker {speaker_id_to_delete}")
    #     service.delete_speaker(speaker_id_to_delete)
    #     print("Speakers after delete:", service.get_speakers())