added single line generation to the backend
This commit is contained in:
parent
6ccdd18463
commit
0261b86ad2
|
@ -1,6 +1,7 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
|
import os
|
||||||
|
|
||||||
from app.models.dialog_models import DialogRequest, DialogResponse
|
from app.models.dialog_models import DialogRequest, DialogResponse
|
||||||
from app.services.tts_service import TTSService
|
from app.services.tts_service import TTSService
|
||||||
|
@ -32,6 +33,68 @@ def get_audio_manipulation_service():
|
||||||
return AudioManipulationService()
|
return AudioManipulationService()
|
||||||
|
|
||||||
# --- Helper function to manage TTS model loading/unloading ---
|
# --- Helper function to manage TTS model loading/unloading ---
|
||||||
|
|
||||||
|
from app.models.dialog_models import SpeechItem, SilenceItem
|
||||||
|
from app.services.tts_service import TTSService
|
||||||
|
from app.services.audio_manipulation_service import AudioManipulationService
|
||||||
|
from app.services.speaker_service import SpeakerManagementService
|
||||||
|
from fastapi import Body
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
@router.post("/generate_line")
|
||||||
|
async def generate_line(
|
||||||
|
item: dict = Body(...),
|
||||||
|
tts_service: TTSService = Depends(get_tts_service),
|
||||||
|
audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service),
|
||||||
|
speaker_service: SpeakerManagementService = Depends(get_speaker_management_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate audio for a single dialog line (speech or silence).
|
||||||
|
Returns the URL of the generated audio file, or error details on failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if item.get("type") == "speech":
|
||||||
|
speech = SpeechItem(**item)
|
||||||
|
filename_base = f"line_{uuid.uuid4().hex}"
|
||||||
|
out_dir = Path(config.DIALOG_GENERATED_DIR)
|
||||||
|
# Get speaker sample path
|
||||||
|
speaker_info = speaker_service.get_speaker_by_id(speech.speaker_id)
|
||||||
|
if not speaker_info or not getattr(speaker_info, 'sample_path', None):
|
||||||
|
raise HTTPException(status_code=404, detail=f"Speaker sample for ID '{speech.speaker_id}' not found.")
|
||||||
|
speaker_sample_path = speaker_info.sample_path
|
||||||
|
# Ensure absolute path
|
||||||
|
if not os.path.isabs(speaker_sample_path):
|
||||||
|
speaker_sample_path = str((Path(config.SPEAKER_SAMPLES_DIR) / Path(speaker_sample_path).name).resolve())
|
||||||
|
# Generate speech (async)
|
||||||
|
out_path = await tts_service.generate_speech(
|
||||||
|
text=speech.text,
|
||||||
|
speaker_sample_path=speaker_sample_path,
|
||||||
|
output_filename_base=filename_base,
|
||||||
|
speaker_id=speech.speaker_id,
|
||||||
|
output_dir=out_dir,
|
||||||
|
exaggeration=speech.exaggeration,
|
||||||
|
cfg_weight=speech.cfg_weight,
|
||||||
|
temperature=speech.temperature
|
||||||
|
)
|
||||||
|
audio_url = f"/generated_audio/{out_path.name}"
|
||||||
|
elif item.get("type") == "silence":
|
||||||
|
silence = SilenceItem(**item)
|
||||||
|
filename = f"silence_{uuid.uuid4().hex}.wav"
|
||||||
|
out_path = Path(config.DIALOG_GENERATED_DIR) / filename
|
||||||
|
# Generate silence tensor and save as WAV
|
||||||
|
silence_tensor = audio_manipulator._create_silence(silence.duration)
|
||||||
|
import torchaudio
|
||||||
|
torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate)
|
||||||
|
audio_url = f"/generated_audio/{filename}"
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=400, detail="Unknown dialog item type.")
|
||||||
|
return {"audio_url": audio_url}
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
raise HTTPException(status_code=500, detail=f"Exception: {str(e)}\nTraceback:\n{tb}")
|
||||||
|
|
||||||
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
||||||
"""Loads TTS model, executes task, then unloads model."""
|
"""Loads TTS model, executes task, then unloads model."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -16,7 +16,7 @@ DIALOG_PAYLOAD = {
|
||||||
"dialog_items": [
|
"dialog_items": [
|
||||||
{
|
{
|
||||||
"type": "speech",
|
"type": "speech",
|
||||||
"speaker_id": "dummy_speaker", # Ensure this speaker exists in your speakers.yaml and has a sample .wav
|
"speaker_id": "90fcd672-ba84-441a-ac6c-0449a59653bd", # Correct UUID for dummy_speaker
|
||||||
"text": "This is a test from the Python script. One, two, three.",
|
"text": "This is a test from the Python script. One, two, three.",
|
||||||
"exaggeration": 1.5,
|
"exaggeration": 1.5,
|
||||||
"cfg_weight": 4.0,
|
"cfg_weight": 4.0,
|
||||||
|
@ -28,7 +28,7 @@ DIALOG_PAYLOAD = {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "speech",
|
"type": "speech",
|
||||||
"speaker_id": "dummy_speaker",
|
"speaker_id": "90fcd672-ba84-441a-ac6c-0449a59653bd",
|
||||||
"text": "Testing complete. All systems nominal."
|
"text": "Testing complete. All systems nominal."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -104,5 +104,49 @@ def run_test():
|
||||||
print(f"An unexpected error occurred: {e}")
|
print(f"An unexpected error occurred: {e}")
|
||||||
print("Test FAILED (Unexpected error)")
|
print("Test FAILED (Unexpected error)")
|
||||||
|
|
||||||
|
def test_generate_line_speech():
|
||||||
|
url = f"{API_BASE_URL}/generate_line"
|
||||||
|
payload = {
|
||||||
|
"type": "speech",
|
||||||
|
"speaker_id": "90fcd672-ba84-441a-ac6c-0449a59653bd", # Correct UUID for dummy_speaker
|
||||||
|
"text": "This is a per-line TTS test.",
|
||||||
|
"exaggeration": 1.0,
|
||||||
|
"cfg_weight": 2.0,
|
||||||
|
"temperature": 0.8
|
||||||
|
}
|
||||||
|
print(f"\nTesting /generate_line with speech item: {payload}")
|
||||||
|
response = requests.post(url, json=payload)
|
||||||
|
print(f"Status: {response.status_code}")
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
print(f"Response: {json.dumps(data, indent=2)}")
|
||||||
|
if response.status_code == 200 and "audio_url" in data:
|
||||||
|
print("Speech line test PASSED.")
|
||||||
|
else:
|
||||||
|
print("Speech line test FAILED.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Speech line test FAILED: {e}")
|
||||||
|
|
||||||
|
def test_generate_line_silence():
|
||||||
|
url = f"{API_BASE_URL}/generate_line"
|
||||||
|
payload = {
|
||||||
|
"type": "silence",
|
||||||
|
"duration": 1.25
|
||||||
|
}
|
||||||
|
print(f"\nTesting /generate_line with silence item: {payload}")
|
||||||
|
response = requests.post(url, json=payload)
|
||||||
|
print(f"Status: {response.status_code}")
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
print(f"Response: {json.dumps(data, indent=2)}")
|
||||||
|
if response.status_code == 200 and "audio_url" in data:
|
||||||
|
print("Silence line test PASSED.")
|
||||||
|
else:
|
||||||
|
print("Silence line test FAILED.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Silence line test FAILED: {e}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_test()
|
run_test()
|
||||||
|
test_generate_line_speech()
|
||||||
|
test_generate_line_silence()
|
||||||
|
|
|
@ -10,3 +10,6 @@
|
||||||
fb84ce1c-f32d-4df9-9673-2c64e9603133:
|
fb84ce1c-f32d-4df9-9673-2c64e9603133:
|
||||||
name: Debbie
|
name: Debbie
|
||||||
sample_path: speaker_samples/fb84ce1c-f32d-4df9-9673-2c64e9603133.wav
|
sample_path: speaker_samples/fb84ce1c-f32d-4df9-9673-2c64e9603133.wav
|
||||||
|
90fcd672-ba84-441a-ac6c-0449a59653bd:
|
||||||
|
name: dummy_speaker
|
||||||
|
sample_path: speaker_samples/90fcd672-ba84-441a-ac6c-0449a59653bd.wav
|
||||||
|
|
Loading…
Reference in New Issue