148 lines
6.7 KiB
Python
148 lines
6.7 KiB
Python
import yaml
|
|
import uuid
|
|
import os
|
|
import io # Added for BytesIO
|
|
import torchaudio # Added for audio processing
|
|
from pathlib import Path
|
|
from typing import List, Dict, Optional, Any
|
|
|
|
from fastapi import UploadFile, HTTPException
|
|
from app.models.speaker_models import Speaker, SpeakerCreate
|
|
from app import config
|
|
|
|
class SpeakerManagementService:
    """Manage speaker records and their audio samples on disk.

    Speaker metadata is kept in the YAML file at ``config.SPEAKERS_YAML_FILE``
    as a mapping ``{speaker_id: {"name": ..., "sample_path": ...}}``, where
    ``sample_path`` is relative to ``config.SPEAKER_DATA_BASE_DIR``. Audio
    samples are written as WAV files under ``config.SPEAKER_SAMPLES_DIR``.

    The mapping is loaded once at construction and cached in
    ``self.speakers_data``; every mutation rewrites the YAML file.
    """

    def __init__(self):
        """Create the storage layout if needed and load the speaker mapping."""
        self._ensure_data_files_exist()
        self.speakers_data = self._load_speakers_data()

    def _ensure_data_files_exist(self):
        """Ensure the speaker data directories and the YAML file exist."""
        config.SPEAKER_DATA_BASE_DIR.mkdir(parents=True, exist_ok=True)
        config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
        if not config.SPEAKERS_YAML_FILE.exists():
            with open(config.SPEAKERS_YAML_FILE, 'w', encoding="utf-8") as f:
                # The registry is a mapping keyed by speaker ID, so a fresh
                # file starts out as an empty mapping.
                yaml.dump({}, f)

    def _load_speakers_data(self) -> Dict[str, Any]:
        """Load the speaker mapping from the YAML file.

        Returns:
            The parsed mapping, or an empty dict when the file is missing,
            cannot be parsed as YAML, or does not hold a mapping at top level
            (an empty file parses to ``None``, for example).
        """
        try:
            with open(config.SPEAKERS_YAML_FILE, 'r', encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except FileNotFoundError:
            return {}
        except yaml.YAMLError:
            # NOTE: falling back to an empty mapping means the next call to
            # _save_speakers_data() replaces the unreadable file.
            print(f"Error: Corrupted speakers YAML file at {config.SPEAKERS_YAML_FILE}")
            return {}
        return data if isinstance(data, dict) else {}

    def _save_speakers_data(self):
        """Persist ``self.speakers_data`` to the YAML file.

        The mapping is dumped to a sibling temporary file which then replaces
        the real file, so a failure while dumping does not leave a truncated
        registry behind.

        Raises:
            OSError: If the temporary file cannot be written or moved.
            yaml.YAMLError: If the data cannot be serialized.
        """
        target = Path(config.SPEAKERS_YAML_FILE)
        scratch = target.with_name(f"{target.name}.tmp")
        try:
            with open(scratch, 'w', encoding="utf-8") as f:
                yaml.dump(self.speakers_data, f, sort_keys=False)
            os.replace(scratch, target)
        except Exception:
            # Do not leave a half-written temporary file next to the registry.
            if scratch.exists():
                os.remove(scratch)
            raise

    @staticmethod
    def _remove_sample_file(sample_path: Path) -> None:
        """Remove a sample file if present; report (not raise) OS failures."""
        try:
            if sample_path.is_file():  # Only ever remove regular files
                os.remove(sample_path)
        except OSError as e:
            # Deliberately best-effort: a leftover file must not fail the caller.
            print(f"Error deleting sample file {sample_path}: {e}")

    def get_speakers(self) -> List[Speaker]:
        """Return all known speakers as ``Speaker`` models."""
        return [
            Speaker(id=spk_id, **spk_attrs)
            for spk_id, spk_attrs in self.speakers_data.items()
        ]

    def get_speaker_by_id(self, speaker_id: str) -> Optional[Speaker]:
        """Return the speaker stored under ``speaker_id``, or ``None``."""
        speaker_attributes = self.speakers_data.get(speaker_id)
        if speaker_attributes is None:
            return None
        return Speaker(id=speaker_id, **speaker_attributes)

    async def add_speaker(self, name: str, audio_file: UploadFile) -> Speaker:
        """Register a new speaker from an uploaded audio sample.

        The upload is decoded with torchaudio, re-encoded as
        ``<speaker_id>.wav`` in ``config.SPEAKER_SAMPLES_DIR``, and a record
        is added to the YAML registry.

        Args:
            name: Display name stored for the speaker.
            audio_file: Uploaded sample; it is always closed before returning.

        Returns:
            The created ``Speaker`` with a freshly generated UUID as its ID.

        Raises:
            HTTPException: 400 if the upload cannot be decoded as audio,
                500 if reading the upload or writing the WAV file fails.
        """
        speaker_id = str(uuid.uuid4())

        # Samples are always stored as WAV, named after the speaker ID.
        sample_path = config.SPEAKER_SAMPLES_DIR / f"{speaker_id}.wav"

        try:
            content = await audio_file.read()
            try:
                # torchaudio decodes from a file-like object, so wrap the bytes.
                waveform, sample_rate = torchaudio.load(io.BytesIO(content))
            except (RuntimeError, ValueError) as e:
                # NOTE(review): the earlier handler named
                # `torchaudio.TorchaudioException`, which torchaudio does not
                # appear to export; decode failures are assumed to surface as
                # RuntimeError/ValueError -- confirm against the installed
                # torchaudio version.
                raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}. Ensure it's a valid audio format (e.g., WAV, MP3).") from e

            # The directory is created at start-up; recreate it in case it
            # was removed while the service was running.
            config.SPEAKER_SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
            torchaudio.save(str(sample_path), waveform, sample_rate, format="wav")
        except HTTPException:
            # Already a client-facing error; do not rewrap it as a 500.
            raise
        except Exception as e:
            # Anything else (read failure, file system error, encoder error).
            self._remove_sample_file(sample_path)  # Drop a partially written WAV
            raise HTTPException(status_code=500, detail=f"Could not save audio file: {e}") from e
        finally:
            await audio_file.close()

        # Stored relative to the speaker data directory.
        relative_sample_path = str(sample_path.relative_to(config.SPEAKER_DATA_BASE_DIR))

        self.speakers_data[speaker_id] = {
            "name": name,
            "sample_path": relative_sample_path,
        }
        try:
            self._save_speakers_data()
        except Exception:
            # Keep memory, registry file and sample directory consistent when
            # the record could not be persisted.
            self.speakers_data.pop(speaker_id, None)
            self._remove_sample_file(sample_path)
            raise

        return Speaker(id=speaker_id, name=name, sample_path=relative_sample_path)

    def delete_speaker(self, speaker_id: str) -> bool:
        """Delete a speaker record and its audio sample.

        Args:
            speaker_id: ID of the speaker to remove.

        Returns:
            ``True`` if a record existed and was removed, ``False`` otherwise.
            Failure to delete the sample file is reported but does not change
            the result.
        """
        # Test membership explicitly: an empty (falsy) record is still a
        # record and must be removed from the registry file as well.
        if speaker_id not in self.speakers_data:
            return False

        speaker_to_delete = self.speakers_data.pop(speaker_id)
        try:
            self._save_speakers_data()
        except Exception:
            # Persisting failed: restore the in-memory entry so memory and
            # the registry file still agree.
            self.speakers_data[speaker_id] = speaker_to_delete
            raise

        sample_path_str = (
            speaker_to_delete.get("sample_path")
            if isinstance(speaker_to_delete, dict)
            else None
        )
        if sample_path_str:
            # sample_path_str is relative to SPEAKER_DATA_BASE_DIR.
            self._remove_sample_file(config.SPEAKER_DATA_BASE_DIR / sample_path_str)
        return True
|
|
|
|
# Example usage (for testing, not part of the service itself)
|
|
if __name__ == "__main__":
    # Manual smoke test: build the service against the configured storage
    # and show the speakers that are already registered.
    service = SpeakerManagementService()
    existing_speakers = service.get_speakers()
    print("Initial speakers:", existing_speakers)

    # Exercising add/delete directly requires a stand-in for UploadFile.
    # Sketch of such a manual test:
    #
    # print("\nAdding a new speaker (manual test setup needed for UploadFile)")
    #
    # class MockUploadFile:
    #     def __init__(self, filename, content):
    #         self.filename = filename
    #         self._content = content
    #
    #     async def read(self):
    #         return self._content
    #
    #     async def close(self):
    #         pass
    #
    # import asyncio
    #
    # async def test_add():
    #     mock_file = MockUploadFile("test.wav", b"dummy audio content")
    #     new_speaker = await service.add_speaker(name="Test Speaker", audio_file=mock_file)
    #     print("\nAdded speaker:", new_speaker)
    #     print("Speakers after add:", service.get_speakers())
    #     return new_speaker.id
    #
    # speaker_id_to_delete = asyncio.run(test_add())
    # if speaker_id_to_delete:
    #     print(f"\nDeleting speaker {speaker_id_to_delete}")
    #     service.delete_speaker(speaker_id_to_delete)
    #     print("Speakers after delete:", service.get_speakers())
|