190 lines
8.7 KiB
Python
190 lines
8.7 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|
from pathlib import Path
|
|
import shutil
|
|
|
|
from app.models.dialog_models import DialogRequest, DialogResponse
|
|
from app.services.tts_service import TTSService
|
|
from app.services.speaker_service import SpeakerManagementService
|
|
from app.services.dialog_processor_service import DialogProcessorService
|
|
from app.services.audio_manipulation_service import AudioManipulationService
|
|
from app import config
|
|
|
|
router = APIRouter()
|
|
|
|
# --- Dependency Injection for Services ---
|
|
# These providers could move to a dedicated DI container (or to cached / lifespan-scoped FastAPI dependencies) if service initialization becomes expensive.
|
|
# For now, direct instantiation or simple Depends is fine.
|
|
|
|
def get_tts_service():
    """Construct the TTSService used for a request.

    A new instance is created on every call. The device string is currently
    hard-coded to "mps".
    """
    # TODO: make the device configurable (e.g. via app.config) instead of hard-coding it.
    # It must not become a parameter here, or FastAPI would expose it as a query parameter.
    device_name = "mps"
    return TTSService(device=device_name)
|
|
|
|
def get_speaker_management_service():
    """Construct a SpeakerManagementService; a new instance is created on every call."""
    speaker_service = SpeakerManagementService()
    return speaker_service
|
|
|
|
def get_dialog_processor_service(
    tts_service: TTSService = Depends(get_tts_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service)
):
    """Build a DialogProcessorService from its injected TTS and speaker services.

    NOTE(review): FastAPI's default per-request dependency caching should make
    ``tts_service`` the same instance the endpoint receives, so the model that
    manage_tts_model_lifecycle loads is the one this processor uses. Confirm
    this if ``use_cache`` is ever disabled on these Depends.
    """
    processor = DialogProcessorService(
        speaker_service=speaker_service,
        tts_service=tts_service,
    )
    return processor
|
|
|
|
def get_audio_manipulation_service():
    """Construct an AudioManipulationService; a new instance is created on every call."""
    audio_service = AudioManipulationService()
    return audio_service
|
|
|
|
# --- Helper function to manage TTS model loading/unloading ---
|
|
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
    """Load the TTS model, await ``task_function(*args, **kwargs)``, then unload the model.

    The model is unloaded in every case, including when loading or the task
    raises. Any ``Exception`` is printed and re-raised unchanged.

    Args:
        tts_service: Service whose ``load_model`` / ``unload_model`` bracket the task.
        task_function: Coroutine function to await while the model is loaded.
        *args, **kwargs: Forwarded verbatim to ``task_function``.

    Returns:
        Whatever ``task_function`` returns.
    """
    # NOTE(review): load_model/unload_model are called synchronously (not awaited);
    # if they are slow they block the event loop for the duration.
    try:
        print("API: Loading TTS model...")
        tts_service.load_model()
        task_result = await task_function(*args, **kwargs)
    except Exception as exc:
        print(f"API: Error during TTS model lifecycle or task execution: {exc}")
        raise
    else:
        return task_result
    finally:
        print("API: Unloading TTS model...")
        tts_service.unload_model()
|
|
|
|
def _ensure_plain_output_base_name(output_base_name):
    """Raise ValueError if *output_base_name* contains path components (e.g. '../x', 'a/b')."""
    # The name is joined into a server-side directory path below, so a client
    # must not be able to escape DIALOG_GENERATED_DIR with separators or '..'.
    candidate = str(output_base_name)
    if Path(candidate).name != candidate or candidate == "..":
        raise ValueError(
            f"output_base_name must be a plain file name without path components: {candidate!r}"
        )


def _collect_segment_items(segment_details):
    """Split segment records into concatenation items and existing speech file paths.

    Speech records are kept only when they have a ``path`` that exists on disk.
    Silence records are kept only when they have a ``duration``. Every other
    record is skipped; errors are already logged by the dialog processor.

    Returns:
        ``(items_for_concatenation, speech_paths)``. The first element lists
        ``{'type': 'speech', 'path': ...}`` / ``{'type': 'silence', 'duration': ...}``
        dicts in original order. The second lists the ``Path`` of every kept
        speech segment.
    """
    items_for_concatenation = []
    speech_paths = []
    for detail in segment_details:
        segment_type = detail.get('type')
        if segment_type == 'speech':
            raw_path = detail.get('path')
            if raw_path and Path(raw_path).exists():
                items_for_concatenation.append({'type': 'speech', 'path': raw_path})
                speech_paths.append(Path(raw_path))
        elif segment_type == 'silence' and 'duration' in detail:
            items_for_concatenation.append({'type': 'silence', 'duration': detail['duration']})
    return items_for_concatenation, speech_paths


async def process_dialog_flow(
    request: DialogRequest,
    dialog_processor: DialogProcessorService,
    audio_manipulator: AudioManipulationService,
    background_tasks: BackgroundTasks
) -> DialogResponse:
    """Generate dialog segments, concatenate them into one WAV, and ZIP the results.

    Args:
        request: Incoming dialog request. Its ``dialog_items`` are passed to the
            dialog processor, and ``output_base_name`` prefixes the output file
            names.
        dialog_processor: Service that generates the per-item segments and
            returns a dict with ``log``, ``segment_files`` and ``temp_dir``.
        audio_manipulator: Service that concatenates segments and builds the ZIP.
        background_tasks: Currently unused. It is kept so temp-dir cleanup can
            be scheduled if automatic deletion is re-enabled.

    Returns:
        A ``DialogResponse`` with the processing log and the temporary segment
        directory. On success it also carries ``/generated_audio/...`` URLs for
        the concatenated WAV and the ZIP. If no speech segment was produced, it
        carries ``error_message`` instead.

    Raises:
        HTTPException: Status is chosen by the underlying error:
            - ``FileNotFoundError``: 404.
            - ``ValueError``, including an ``output_base_name`` with path
              components: 400.
            - ``RuntimeError`` or any other unexpected error: 500.
            An ``HTTPException`` raised by a collaborator propagates unchanged.
    """
    processing_log_entries = []
    final_temp_dir_path_str = None
    # Set only right before the success response is returned; drives the failure log below.
    succeeded = False

    try:
        _ensure_plain_output_base_name(request.output_base_name)

        # 1. Process dialog to generate segments.
        # The DialogProcessorService creates its own temp dir for segments.
        dialog_processing_result = await dialog_processor.process_dialog(
            dialog_items=[item.model_dump() for item in request.dialog_items],
            output_base_name=request.output_base_name
        )
        processing_log_entries.append(dialog_processing_result['log'])
        segment_details = dialog_processing_result['segment_files']
        temp_segment_dir = Path(dialog_processing_result['temp_dir'])
        final_temp_dir_path_str = str(temp_segment_dir)

        # Drop error segments; compute the speech file list once for both concat and ZIP.
        items_for_concatenation, speech_segment_paths = _collect_segment_items(segment_details)

        if not speech_segment_paths:
            message = "No valid speech segments were generated. Cannot create concatenated audio or ZIP."
            processing_log_entries.append(message)
            return DialogResponse(
                log="\n".join(processing_log_entries),
                temp_dir_path=final_temp_dir_path_str,
                error_message=message
            )

        # 2. Concatenate audio segments.
        config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
        concat_filename = f"{request.output_base_name}_concatenated.wav"
        concatenated_audio_file_path = config.DIALOG_GENERATED_DIR / concat_filename
        audio_manipulator.concatenate_audio_segments(
            segment_results=items_for_concatenation,
            output_concatenated_path=concatenated_audio_file_path
        )
        processing_log_entries.append(f"Concatenated audio saved to: {concatenated_audio_file_path}")

        # 3. Create ZIP archive of the speech segments plus the concatenated file.
        zip_filename = f"{request.output_base_name}_dialog_output.zip"
        zip_archive_file_path = config.DIALOG_GENERATED_DIR / zip_filename
        audio_manipulator.create_zip_archive(
            segment_file_paths=speech_segment_paths,
            concatenated_audio_path=concatenated_audio_file_path,
            output_zip_path=zip_archive_file_path
        )
        processing_log_entries.append(f"ZIP archive created at: {zip_archive_file_path}")

        # The temp segment dir is intentionally not auto-deleted so it can be inspected;
        # cleanup could later be scheduled via background_tasks or a separate job.
        processing_log_entries.append(f"Temporary segment directory for inspection: {temp_segment_dir}")

        response = DialogResponse(
            log="\n".join(processing_log_entries),
            # URLs assume the files in DIALOG_GENERATED_DIR are served under /generated_audio/.
            concatenated_audio_url=f"/generated_audio/{concat_filename}",
            zip_archive_url=f"/generated_audio/{zip_filename}",
            temp_dir_path=final_temp_dir_path_str
        )
        succeeded = True
        return response

    except HTTPException:
        # Already an HTTP error with a deliberate status; don't rewrap it as a 500.
        raise
    except FileNotFoundError as e:
        error_msg = f"File not found during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=404, detail=error_msg) from e
    except ValueError as e:
        error_msg = f"Invalid value or configuration: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=400, detail=error_msg) from e
    except RuntimeError as e:
        error_msg = f"Runtime error during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=500, detail=error_msg) from e
    except Exception as e:
        import traceback
        error_msg = f"An unexpected error occurred: {e}"
        # Keep the traceback in the server-side log only; don't leak it to clients.
        processing_log_entries.append(f"{error_msg}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg) from e
    finally:
        # Print the accumulated log for every unsuccessful run (early "no speech"
        # return, or any exception at any stage).
        if not succeeded and processing_log_entries:
            print("Dialog generation failed. Log: \n" + "\n".join(processing_log_entries))
|
|
|
|
@router.post("/generate", response_model=DialogResponse)
async def generate_dialog_endpoint(
    request: DialogRequest,
    background_tasks: BackgroundTasks,
    tts_service: TTSService = Depends(get_tts_service),
    dialog_processor: DialogProcessorService = Depends(get_dialog_processor_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service)
):
    """
    Generates a dialog from a list of speech and silence items.

    - Processes text into manageable chunks.
    - Generates speech for each chunk using the specified speaker.
    - Inserts silences as requested.
    - Concatenates all audio segments into a single file.
    - Creates a ZIP archive of all individual segments and the concatenated file.
    """
    flow_kwargs = {
        "request": request,
        "dialog_processor": dialog_processor,
        "audio_manipulator": audio_manipulator,
        "background_tasks": background_tasks,
    }
    # The TTS model is loaded for the duration of the pipeline and unloaded afterwards.
    return await manage_tts_model_lifecycle(tts_service, process_dialog_flow, **flow_kwargs)
|