from fastapi import APIRouter, Body, Depends, HTTPException, BackgroundTasks
from pathlib import Path
from typing import AsyncIterator
import os
import shutil
import uuid

from app.models.dialog_models import DialogRequest, DialogResponse, SpeechItem, SilenceItem
from app.services.tts_service import TTSService
from app.services.speaker_service import SpeakerManagementService
from app.services.dialog_processor_service import DialogProcessorService
from app.services.audio_manipulation_service import AudioManipulationService
from app.services.model_manager import ModelManager
from app import config

router = APIRouter()
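
# Example wiring (illustrative only; the actual mount point and any URL prefix
# are decided where the FastAPI app is assembled, not in this module):
#   app.include_router(router, prefix="/api/dialog")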

# --- Dependency Injection for Services ---
# These could be more sophisticated with a proper DI container or FastAPI's Depends
# system if the services had complex initialization.
# For now, direct instantiation or a simple Depends is fine.

async def get_tts_service() -> AsyncIterator[TTSService]:
    """Dependency that holds a usage token for the duration of the request."""
    manager = ModelManager.instance()
    async with manager.using():
        service = await manager.get_service()
        yield service


def get_speaker_management_service():
    return SpeakerManagementService()


def get_dialog_processor_service(
    tts_service: TTSService = Depends(get_tts_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service),
):
    return DialogProcessorService(tts_service=tts_service, speaker_service=speaker_service)


def get_audio_manipulation_service():
    return AudioManipulationService()


@router.post("/generate_line")
async def generate_line(
    item: dict = Body(...),
    tts_service: TTSService = Depends(get_tts_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service),
):
    """
    Generate audio for a single dialog line (speech or silence).

    Returns the URL of the generated audio file, or error details on failure.
    """
    try:
        if item.get("type") == "speech":
            speech = SpeechItem(**item)
            filename_base = f"line_{uuid.uuid4().hex}"
            out_dir = Path(config.DIALOG_GENERATED_DIR)

            # Get speaker sample path
            speaker_info = speaker_service.get_speaker_by_id(speech.speaker_id)
            if not speaker_info or not getattr(speaker_info, 'sample_path', None):
                raise HTTPException(status_code=404, detail=f"Speaker sample for ID '{speech.speaker_id}' not found.")
            speaker_sample_path = speaker_info.sample_path
            # Ensure absolute path
            if not os.path.isabs(speaker_sample_path):
                speaker_sample_path = str((Path(config.SPEAKER_SAMPLES_DIR) / Path(speaker_sample_path).name).resolve())

            # Generate speech (async)
            out_path = await tts_service.generate_speech(
                text=speech.text,
                speaker_sample_path=speaker_sample_path,
                output_filename_base=filename_base,
                speaker_id=speech.speaker_id,
                output_dir=out_dir,
                exaggeration=speech.exaggeration,
                cfg_weight=speech.cfg_weight,
                temperature=speech.temperature,
            )
            audio_url = f"/generated_audio/{out_path.name}"
            return {"audio_url": audio_url}
        elif item.get("type") == "silence":
            silence = SilenceItem(**item)
            filename = f"silence_{uuid.uuid4().hex}.wav"
            out_dir = Path(config.DIALOG_GENERATED_DIR)
            out_dir.mkdir(parents=True, exist_ok=True)  # Ensure output directory exists
            out_path = out_dir / filename

            try:
                # Generate silence
                silence_tensor = audio_manipulator.generate_silence(silence.duration)
                import torchaudio
                torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate)

                if not out_path.exists() or out_path.stat().st_size == 0:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to generate silence. Output file not created: {out_path}"
                    )

                audio_url = f"/generated_audio/{filename}"
                return {"audio_url": audio_url}

            except HTTPException:
                raise
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating silence: {str(e)}"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown dialog item type: {item.get('type')}. Expected 'speech' or 'silence'."
            )

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        import traceback
        tb = traceback.format_exc()
        error_detail = f"Unexpected error: {str(e)}\n\nTraceback:\n{tb}"
        print(error_detail)  # Log to console for debugging
        raise HTTPException(
            status_code=500,
            detail=error_detail
        )
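

# Illustrative call (assumes a local dev server; the URL depends on the host,
# port, and any prefix the router is mounted under):
#   curl -X POST http://localhost:8000/generate_line \
#        -H "Content-Type: application/json" \
#        -d '{"type": "silence", "duration": 1.0}'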
# Removed per-request load/unload in favor of ModelManager idle eviction.


async def process_dialog_flow(
    request: DialogRequest,
    dialog_processor: DialogProcessorService,
    audio_manipulator: AudioManipulationService,
    background_tasks: BackgroundTasks,
) -> DialogResponse:
    """Core logic for processing the dialog request."""
    processing_log_entries = []
    concatenated_audio_file_path = None
    zip_archive_file_path = None
    final_temp_dir_path_str = None
    try:
        # 1. Process dialog to generate segments.
        # The DialogProcessorService creates its own temp dir for segments.
        dialog_processing_result = await dialog_processor.process_dialog(
            dialog_items=[item.model_dump() for item in request.dialog_items],
            output_base_name=request.output_base_name
        )
        processing_log_entries.append(dialog_processing_result['log'])
        segment_details = dialog_processing_result['segment_files']
        temp_segment_dir = Path(dialog_processing_result['temp_dir'])
        final_temp_dir_path_str = str(temp_segment_dir)

        # Filter out error segments and build the list the concatenation service
        # expects (speech paths and silence durations).
        items_for_concatenation = []
        for s_detail in segment_details:
            if s_detail['type'] == 'speech' and s_detail.get('path') and Path(s_detail['path']).exists():
                items_for_concatenation.append({'type': 'speech', 'path': s_detail['path']})
            elif s_detail['type'] == 'silence' and 'duration' in s_detail:
                items_for_concatenation.append({'type': 'silence', 'duration': s_detail['duration']})
            # Errors are already logged by DialogProcessor.

        if not any(item['type'] == 'speech' for item in items_for_concatenation):
            message = "No valid speech segments were generated. Cannot create concatenated audio or ZIP."
            processing_log_entries.append(message)
            return DialogResponse(
                log="\n".join(processing_log_entries),
                temp_dir_path=final_temp_dir_path_str,
                error_message=message
            )

        # 2. Concatenate audio segments.
        config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
        concat_filename = f"{request.output_base_name}_concatenated.wav"
        concatenated_audio_file_path = config.DIALOG_GENERATED_DIR / concat_filename

        audio_manipulator.concatenate_audio_segments(
            segment_results=items_for_concatenation,
            output_concatenated_path=concatenated_audio_file_path
        )
        processing_log_entries.append(f"Concatenated audio saved to: {concatenated_audio_file_path}")

        # 3. Create ZIP archive of the individual segments plus the concatenated file.
        zip_filename = f"{request.output_base_name}_dialog_output.zip"
        zip_archive_file_path = config.DIALOG_GENERATED_DIR / zip_filename

        # Collect all valid generated speech segment files for zipping.
        individual_segment_paths = [
            Path(s['path']) for s in segment_details
            if s['type'] == 'speech' and s.get('path') and Path(s['path']).exists()
        ]

        audio_manipulator.create_zip_archive(
            segment_file_paths=individual_segment_paths,
            concatenated_audio_path=concatenated_audio_file_path,
            output_zip_path=zip_archive_file_path
        )
        processing_log_entries.append(f"ZIP archive created at: {zip_archive_file_path}")

        # Cleanup of the temporary segment directory could be scheduled here, e.g.:
        #   background_tasks.add_task(shutil.rmtree, temp_segment_dir, ignore_errors=True)
        # For now the directory is kept so the user can inspect it; cleanup can be a
        # separate endpoint or job.
        processing_log_entries.append(f"Temporary segment directory for inspection: {temp_segment_dir}")

        return DialogResponse(
            log="\n".join(processing_log_entries),
            # URLs are relative to the static serving path for the generated files,
            # e.g. /generated_audio/.
            concatenated_audio_url=f"/generated_audio/{concat_filename}",
            zip_archive_url=f"/generated_audio/{zip_filename}",
            temp_dir_path=final_temp_dir_path_str
        )

    except FileNotFoundError as e:
        error_msg = f"File not found during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=404, detail=error_msg)
    except ValueError as e:
        error_msg = f"Invalid value or configuration: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    except RuntimeError as e:
        error_msg = f"Runtime error during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        # Treated as an unexpected server error.
        raise HTTPException(status_code=500, detail=error_msg)
    except Exception as e:
        import traceback
        error_msg = f"An unexpected error occurred: {e}\n{traceback.format_exc()}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
    finally:
        # Ensure logs are captured even if an early exception occurs before full response construction.
        if not concatenated_audio_file_path and not zip_archive_file_path and processing_log_entries:
            print("Dialog generation failed. Log: \n" + "\n".join(processing_log_entries))


@router.post("/generate", response_model=DialogResponse)
async def generate_dialog_endpoint(
    request: DialogRequest,
    background_tasks: BackgroundTasks,
    tts_service: TTSService = Depends(get_tts_service),
    dialog_processor: DialogProcessorService = Depends(get_dialog_processor_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service),
):
    """
    Generates a dialog from a list of speech and silence items.

    - Processes text into manageable chunks.
    - Generates speech for each chunk using the specified speaker.
    - Inserts silences as requested.
    - Concatenates all audio segments into a single file.
    - Creates a ZIP archive of all individual segments and the concatenated file.
    """
    # Execute core processing; the ModelManager dependency keeps the model marked "in use".
    return await process_dialog_flow(
        request=request,
        dialog_processor=dialog_processor,
        audio_manipulator=audio_manipulator,
        background_tasks=background_tasks,
    )