from app.models.dialog_models import SpeechItem, SilenceItem
from app.services.tts_service import TTSService
from app.services.audio_manipulation_service import AudioManipulationService
from app.services.speaker_service import SpeakerManagementService
from fastapi import Body
import uuid
from pathlib import Path


def _parse_dialog_item(model_cls, item: dict):
    """Build ``model_cls(**item)``, reporting bad client input as HTTP 422.

    NOTE(review): assumes the dialog models raise ``ValueError`` (pydantic's
    ``ValidationError`` is a ``ValueError`` subclass) or ``TypeError`` on
    invalid input -- confirm against ``app.models.dialog_models``.
    """
    try:
        return model_cls(**item)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=422, detail=f"Invalid dialog item: {exc}"
        ) from exc


@router.post("/generate_line")
async def generate_line(
    item: dict = Body(...),
    tts_service: TTSService = Depends(get_tts_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service)
):
    """Generate audio for a single dialog line (speech or silence).

    Args:
        item: Raw JSON object. Its ``"type"`` key selects the branch:
            ``"speech"`` is parsed as a ``SpeechItem`` and ``"silence"`` as a
            ``SilenceItem``.
        tts_service: Injected service used to synthesize speech.
        audio_manipulator: Injected service used to build the silence tensor.
        speaker_service: Injected service used to look up the speaker sample.

    Returns:
        ``{"audio_url": "/generated_audio/<file name>"}`` for the new file.

    Raises:
        HTTPException: 400 for an unknown ``type``; 404 when the speaker has
            no sample; 422 when the item fails model validation; 500 for any
            other failure (the traceback is logged, not sent to the client).
    """
    try:
        item_type = item.get("type")
        if item_type == "speech":
            speech = _parse_dialog_item(SpeechItem, item)
            filename_base = f"line_{uuid.uuid4().hex}"
            out_dir = Path(config.DIALOG_GENERATED_DIR)

            # Look up the speaker's reference sample.
            speaker_info = speaker_service.get_speaker_by_id(speech.speaker_id)
            speaker_sample_path = (
                getattr(speaker_info, "sample_path", None) if speaker_info else None
            )
            if not speaker_sample_path:
                raise HTTPException(
                    status_code=404,
                    detail=f"Speaker sample for ID '{speech.speaker_id}' not found.",
                )

            # Relative sample paths are resolved by file name inside the
            # configured samples directory.
            if not Path(speaker_sample_path).is_absolute():
                speaker_sample_path = str(
                    (Path(config.SPEAKER_SAMPLES_DIR) / Path(speaker_sample_path).name).resolve()
                )

            out_path = await tts_service.generate_speech(
                text=speech.text,
                speaker_sample_path=speaker_sample_path,
                output_filename_base=filename_base,
                speaker_id=speech.speaker_id,
                output_dir=out_dir,
                exaggeration=speech.exaggeration,
                cfg_weight=speech.cfg_weight,
                temperature=speech.temperature,
            )
            audio_url = f"/generated_audio/{out_path.name}"
        elif item_type == "silence":
            silence = _parse_dialog_item(SilenceItem, item)
            filename = f"silence_{uuid.uuid4().hex}.wav"
            out_path = Path(config.DIALOG_GENERATED_DIR) / filename

            # NOTE(review): relies on the service's private _create_silence();
            # consider exposing a public wrapper on AudioManipulationService.
            silence_tensor = audio_manipulator._create_silence(silence.duration)
            import torchaudio  # kept function-local, as in the original
            torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate)
            audio_url = f"/generated_audio/{filename}"
        else:
            raise HTTPException(status_code=400, detail="Unknown dialog item type.")
        return {"audio_url": audio_url}
    except HTTPException:
        # Deliberate 4xx responses raised above must keep their status code;
        # without this clause the handler below would turn them into 500s.
        raise
    except Exception as e:
        import logging
        # Keep the traceback in the server log rather than exposing internals
        # (paths, stack frames) in the HTTP response body.
        logging.getLogger(__name__).exception("generate_line failed")
        raise HTTPException(status_code=500, detail=f"Exception: {e}") from e
def _run_generate_line_test(kind, payload, timeout_s=600.0):
    """POST ``payload`` to ``/generate_line`` and print a PASSED/FAILED verdict.

    Args:
        kind: ``"speech"`` or ``"silence"``; used only in the printed messages.
        payload: JSON-serializable request body.
        timeout_s: Seconds to wait for the server. It is generous because
            speech synthesis may be slow; without a timeout ``requests`` can
            block indefinitely.

    Returns:
        ``True`` when the response is HTTP 200 and contains ``"audio_url"``,
        otherwise ``False``.
    """
    label = kind.capitalize()
    url = f"{API_BASE_URL}/generate_line"
    print(f"\nTesting /generate_line with {kind} item: {payload}")
    try:
        response = requests.post(url, json=payload, timeout=timeout_s)
    except requests.RequestException as e:
        # Server down or timed out: report it and let the remaining tests run
        # instead of aborting the whole script with a traceback.
        print(f"{label} line test FAILED: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        data = response.json()
    except ValueError as e:
        # The body was not valid JSON.
        print(f"{label} line test FAILED: {e}")
        return False
    print(f"Response: {json.dumps(data, indent=2)}")
    passed = response.status_code == 200 and "audio_url" in data
    print(f"{label} line test PASSED." if passed else f"{label} line test FAILED.")
    return passed


def test_generate_line_speech():
    """Exercise ``/generate_line`` with a single speech item."""
    payload = {
        "type": "speech",
        "speaker_id": "90fcd672-ba84-441a-ac6c-0449a59653bd",  # Correct UUID for dummy_speaker
        "text": "This is a per-line TTS test.",
        "exaggeration": 1.0,
        "cfg_weight": 2.0,
        "temperature": 0.8
    }
    return _run_generate_line_test("speech", payload)


def test_generate_line_silence():
    """Exercise ``/generate_line`` with a single silence item."""
    payload = {
        "type": "silence",
        "duration": 1.25
    }
    return _run_generate_line_test("silence", payload)


if __name__ == "__main__":
    run_test()
    test_generate_line_speech()
    test_generate_line_silence()