# chatterbox-ui/backend/app/services/speaker_service.py
#
# 148 lines
# 6.7 KiB
# Python

import yaml
import uuid
import os
import io # Added for BytesIO
import torchaudio # Added for audio processing
from pathlib import Path
from typing import List, Dict, Optional, Any
from fastapi import UploadFile, HTTPException
from app.models.speaker_models import Speaker, SpeakerCreate
from app import config
class SpeakerManagementService:
    """Manage speaker records and their audio samples on disk.

    Speaker metadata is held in memory as a dict keyed by speaker ID
    (``{speaker_id: {"name": ..., "sample_path": ...}}``) and persisted to
    ``config.SPEAKERS_YAML_FILE``. Samples are written as WAV files into
    ``config.SPEAKER_SAMPLES_DIR``; each ``sample_path`` is stored relative
    to ``config.SPEAKER_DATA_BASE_DIR``.
    """

    def __init__(self):
        """Create the data directories/file if missing, then load the registry."""
        self._ensure_data_files_exist()
        self.speakers_data = self._load_speakers_data()

    def _ensure_data_files_exist(self):
        """Ensures the speaker data directories and the YAML file exist."""
        config.SPEAKER_DATA_BASE_DIR.mkdir(parents=True, exist_ok=True)
        config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
        if not config.SPEAKERS_YAML_FILE.exists():
            # Seed the file with an empty mapping so later loads get a dict.
            with open(config.SPEAKERS_YAML_FILE, 'w', encoding="utf-8") as f:
                yaml.dump({}, f)

    def _load_speakers_data(self) -> Dict[str, Any]:
        """Loads speaker data from the YAML file.

        Returns:
            The mapping stored in the file, or an empty dict when the file is
            missing, is not valid YAML, or does not contain a mapping.
        """
        try:
            with open(config.SPEAKERS_YAML_FILE, 'r', encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except FileNotFoundError:
            return {}
        except yaml.YAMLError:
            # A corrupted file is reported and treated as an empty registry.
            print(f"Error: Corrupted speakers YAML file at {config.SPEAKERS_YAML_FILE}")
            return {}
        # An empty file loads as None and a list-shaped file as a list;
        # neither is usable here, so fall back to an empty mapping.
        return data if isinstance(data, dict) else {}

    def _save_speakers_data(self):
        """Saves the current speaker data to the YAML file."""
        with open(config.SPEAKERS_YAML_FILE, 'w', encoding="utf-8") as f:
            # sort_keys=False keeps the entries in insertion order.
            yaml.dump(self.speakers_data, f, sort_keys=False)

    def get_speakers(self) -> List[Speaker]:
        """Returns a list of all speakers, one ``Speaker`` per stored entry."""
        return [
            Speaker(id=spk_id, **spk_attrs)
            for spk_id, spk_attrs in self.speakers_data.items()
        ]

    def get_speaker_by_id(self, speaker_id: str) -> Optional[Speaker]:
        """Retrieves a speaker by their ID, or ``None`` if the ID is unknown."""
        if speaker_id not in self.speakers_data:
            return None
        return Speaker(id=speaker_id, **self.speakers_data[speaker_id])

    async def add_speaker(self, name: str, audio_file: UploadFile) -> Speaker:
        """Adds a new speaker, converts the sample to WAV, and updates the YAML.

        Args:
            name: Display name stored for the speaker.
            audio_file: Uploaded sample; it is read fully into memory, decoded
                with torchaudio, and always closed before returning.

        Returns:
            The newly created ``Speaker`` with a freshly generated UUID.

        Raises:
            HTTPException: 400 when the upload cannot be decoded as audio;
                500 when the upload cannot be read or the WAV cannot be
                written.
        """
        speaker_id = str(uuid.uuid4())
        # Samples are always stored as "<speaker_id>.wav".
        sample_path = config.SPEAKER_SAMPLES_DIR / f"{speaker_id}.wav"

        try:
            content = await audio_file.read()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Could not save audio file: {e}") from e
        finally:
            await audio_file.close()

        # Decoding gets its own try block so that bad client input is reported
        # as 400 rather than being lumped in with server-side save failures.
        # The catch is deliberately broad instead of naming a torchaudio
        # exception class. NOTE(review): the previously referenced
        # `torchaudio.TorchaudioException` does not appear in torchaudio's
        # documented API; if it is absent, evaluating that except clause
        # raises AttributeError and masks the real error -- confirm.
        try:
            # BytesIO lets torchaudio decode the in-memory upload.
            waveform, sample_rate = torchaudio.load(io.BytesIO(content))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}. Ensure it's a valid audio format (e.g., WAV, MP3).") from e

        try:
            # The directory is normally created at start-up; re-check in case
            # it was removed while the service was running.
            config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
            torchaudio.save(str(sample_path), waveform, sample_rate, format="wav")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Could not save audio file: {e}") from e

        # Store the path relative to the speaker data dir.
        relative_sample_path = str(sample_path.relative_to(config.SPEAKER_DATA_BASE_DIR))
        self.speakers_data[speaker_id] = {
            "name": name,
            "sample_path": relative_sample_path,
        }
        self._save_speakers_data()
        return Speaker(id=speaker_id, name=name, sample_path=relative_sample_path)

    def delete_speaker(self, speaker_id: str) -> bool:
        """Deletes a speaker and their audio sample.

        Returns:
            ``True`` if the speaker existed and was removed, ``False`` if the
            ID was unknown. Failure to delete the sample file is reported on
            stdout but does not change the return value.
        """
        speaker_to_delete = self.speakers_data.pop(speaker_id, None)
        # Compare against None explicitly: an existing entry with an empty
        # mapping is falsy, and would otherwise be popped from memory without
        # being persisted and reported as "not found".
        if speaker_to_delete is None:
            return False

        self._save_speakers_data()
        sample_path_str = speaker_to_delete.get("sample_path")
        if sample_path_str:
            # sample_path_str is relative to SPEAKER_DATA_BASE_DIR.
            full_sample_path = config.SPEAKER_DATA_BASE_DIR / sample_path_str
            try:
                if full_sample_path.is_file():
                    os.remove(full_sample_path)
            except OSError as e:
                # Best effort: the registry entry is already gone, so only
                # report the leftover file.
                print(f"Error deleting sample file {full_sample_path}: {e}")
        return True
# Manual smoke test: runs only when this module is executed as a script and
# is not part of the service itself.
if __name__ == "__main__":
    demo_service = SpeakerManagementService()
    initial_speakers = demo_service.get_speakers()
    print("Initial speakers:", initial_speakers)
    # Exercising add/delete directly needs a stand-in for UploadFile; the
    # disabled sketch below shows one way to set that up.
    # print("\nAdding a new speaker (manual test setup needed for UploadFile)")
    # class MockUploadFile:
    #     def __init__(self, filename, content):
    #         self.filename = filename
    #         self._content = content
    #     async def read(self): return self._content
    #     async def close(self): pass
    # import asyncio
    # async def test_add():
    #     mock_file = MockUploadFile("test.wav", b"dummy audio content")
    #     new_speaker = await demo_service.add_speaker(name="Test Speaker", audio_file=mock_file)
    #     print("\nAdded speaker:", new_speaker)
    #     print("Speakers after add:", demo_service.get_speakers())
    #     return new_speaker.id
    # speaker_id_to_delete = asyncio.run(test_add())
    # if speaker_id_to_delete:
    #     print(f"\nDeleting speaker {speaker_id_to_delete}")
    #     demo_service.delete_speaker(speaker_id_to_delete)
    #     print("Speakers after delete:", demo_service.get_speakers())