import shutil  # kept: available for the optional temp-dir cleanup mentioned in process_dialog_flow
import traceback
from pathlib import Path
from typing import Any, Dict, List, Tuple

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException

from app import config
from app.models.dialog_models import DialogRequest, DialogResponse
from app.services.audio_manipulation_service import AudioManipulationService
from app.services.dialog_processor_service import DialogProcessorService
from app.services.speaker_service import SpeakerManagementService
from app.services.tts_service import TTSService

router = APIRouter()

# Fallback compute device for TTS inference when config does not define TTS_DEVICE.
_DEFAULT_TTS_DEVICE = "mps"


# --- Dependency Injection for Services ---
# Plain factories wired through FastAPI's Depends. FastAPI caches a dependency's
# result within a single request by default, so the TTSService injected into the
# endpoint (whose model is loaded/unloaded by manage_tts_model_lifecycle) is the
# same instance handed to DialogProcessorService for that request.


def get_tts_service() -> TTSService:
    """Create a TTSService for the current request.

    The device is read from ``config.TTS_DEVICE`` when that setting exists;
    otherwise it falls back to ``"mps"`` (the previously hard-coded value).
    """
    device = getattr(config, "TTS_DEVICE", _DEFAULT_TTS_DEVICE)
    return TTSService(device=device)


def get_speaker_management_service() -> SpeakerManagementService:
    """Create a SpeakerManagementService instance."""
    return SpeakerManagementService()


def get_dialog_processor_service(
    tts_service: TTSService = Depends(get_tts_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service),
) -> DialogProcessorService:
    """Build a DialogProcessorService wired to the request's TTS and speaker services."""
    return DialogProcessorService(tts_service=tts_service, speaker_service=speaker_service)


def get_audio_manipulation_service() -> AudioManipulationService:
    """Create an AudioManipulationService instance."""
    return AudioManipulationService()


# --- Helper function to manage TTS model loading/unloading ---
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
    """Load the TTS model, await ``task_function(*args, **kwargs)``, then unload the model.

    The model is unloaded in ``finally``, so it is released whether the task
    succeeds or raises. Any exception from loading or from the task is printed
    and re-raised unchanged.

    NOTE(review): ``load_model``/``unload_model`` are called synchronously from
    this coroutine; if they are slow they will block the event loop.
    """
    try:
        print("API: Loading TTS model...")
        tts_service.load_model()
        return await task_function(*args, **kwargs)
    except Exception as e:
        print(f"API: Error during TTS model lifecycle or task execution: {e}")
        raise
    finally:
        print("API: Unloading TTS model...")
        tts_service.unload_model()


def _validate_output_base_name(output_base_name: Any) -> None:
    """Reject base names that could place output files outside the output directory.

    ``output_base_name`` comes straight from the request and is interpolated into
    file names that are joined onto ``config.DIALOG_GENERATED_DIR``. A value with a
    path separator (``../x``) or an absolute prefix (``/etc/x``, which makes
    ``Path / name`` discard the base directory) would let a client write elsewhere.

    Raises:
        ValueError: if the name is not a single plain path component.
    """
    name = str(output_base_name)
    if "/" in name or "\\" in name or Path(name).name != name:
        raise ValueError(
            f"output_base_name must be a plain file name without path components: {name!r}"
        )


def _partition_segments(
    segment_details: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[Path]]:
    """Split the processor's segment records into concatenation items and speech paths.

    Speech records are kept only when they have a ``path`` that exists on disk;
    silence records are kept only when they carry a ``duration``. Anything else
    (e.g. error entries, which the dialog processor already logs) is skipped.

    Returns:
        ``(items_for_concatenation, speech_paths)``: the ordered list of
        ``{'type': 'speech', 'path': ...}`` / ``{'type': 'silence', 'duration': ...}``
        dicts for the concatenation service, and the ``Path`` of every kept speech
        segment (in order) for the ZIP archive.
    """
    items_for_concatenation: List[Dict[str, Any]] = []
    speech_paths: List[Path] = []
    for segment in segment_details:
        if segment['type'] == 'speech' and segment.get('path') and Path(segment['path']).exists():
            items_for_concatenation.append({'type': 'speech', 'path': segment['path']})
            speech_paths.append(Path(segment['path']))
        elif segment['type'] == 'silence' and 'duration' in segment:
            items_for_concatenation.append({'type': 'silence', 'duration': segment['duration']})
    return items_for_concatenation, speech_paths


async def process_dialog_flow(
    request: DialogRequest,
    dialog_processor: DialogProcessorService,
    audio_manipulator: AudioManipulationService,
    background_tasks: BackgroundTasks,
) -> DialogResponse:
    """Core logic for processing the dialog request.

    Steps:
      1. ``dialog_processor.process_dialog`` generates the per-item segments
         (it creates its own temporary directory for them).
      2. Valid speech segments and requested silences are concatenated into
         ``<output_base_name>_concatenated.wav`` in ``config.DIALOG_GENERATED_DIR``.
      3. The individual speech segments plus the concatenated file are zipped into
         ``<output_base_name>_dialog_output.zip`` in the same directory.

    If no valid speech segment was produced, a ``DialogResponse`` carrying
    ``error_message`` (and no URLs) is returned instead of raising.

    ``background_tasks`` is currently unused: the temporary segment directory is
    deliberately left on disk for inspection (cleanup could be scheduled with
    ``background_tasks.add_task(shutil.rmtree, temp_segment_dir, ignore_errors=True)``
    or handled by a separate endpoint/job).

    Raises:
        HTTPException: 404 for ``FileNotFoundError``; 400 for ``ValueError``
            (including an ``output_base_name`` containing path components);
            500 for ``RuntimeError`` and any other unexpected exception. The
            client-facing detail never includes a traceback; the traceback is
            printed server-side with the processing log.
    """
    processing_log_entries: List[str] = []
    final_temp_dir_path_str = None
    # Set only once the full success response has been built; drives failure logging.
    succeeded = False

    try:
        _validate_output_base_name(request.output_base_name)

        # 1. Process dialog to generate segments.
        dialog_processing_result = await dialog_processor.process_dialog(
            dialog_items=[item.model_dump() for item in request.dialog_items],
            output_base_name=request.output_base_name,
        )
        processing_log_entries.append(dialog_processing_result['log'])
        segment_details = dialog_processing_result['segment_files']
        temp_segment_dir = Path(dialog_processing_result['temp_dir'])
        final_temp_dir_path_str = str(temp_segment_dir)

        items_for_concatenation, speech_segment_paths = _partition_segments(segment_details)

        if not speech_segment_paths:
            message = "No valid speech segments were generated. Cannot create concatenated audio or ZIP."
            processing_log_entries.append(message)
            return DialogResponse(
                log="\n".join(processing_log_entries),
                temp_dir_path=final_temp_dir_path_str,
                error_message=message,
            )

        # 2. Concatenate audio segments.
        config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
        concat_filename = f"{request.output_base_name}_concatenated.wav"
        concatenated_audio_file_path = config.DIALOG_GENERATED_DIR / concat_filename
        audio_manipulator.concatenate_audio_segments(
            segment_results=items_for_concatenation,
            output_concatenated_path=concatenated_audio_file_path,
        )
        processing_log_entries.append(f"Concatenated audio saved to: {concatenated_audio_file_path}")

        # 3. Create ZIP archive of the individual speech segments plus the concatenated file.
        zip_filename = f"{request.output_base_name}_dialog_output.zip"
        zip_archive_path = config.DIALOG_GENERATED_DIR / zip_filename
        audio_manipulator.create_zip_archive(
            segment_file_paths=speech_segment_paths,
            concatenated_audio_path=concatenated_audio_file_path,
            output_zip_path=zip_archive_path,
        )
        processing_log_entries.append(f"ZIP archive created at: {zip_archive_path}")

        # The temp segment directory is intentionally not deleted (see docstring).
        processing_log_entries.append(f"Temporary segment directory for inspection: {temp_segment_dir}")

        response = DialogResponse(
            log="\n".join(processing_log_entries),
            # URLs are relative to a static serving path; this presumes /generated_audio/
            # serves config.DIALOG_GENERATED_DIR — confirm against the app's static mount.
            concatenated_audio_url=f"/generated_audio/{concat_filename}",
            zip_archive_url=f"/generated_audio/{zip_filename}",
            temp_dir_path=final_temp_dir_path_str,
        )
        succeeded = True
        return response

    except FileNotFoundError as e:
        error_msg = f"File not found during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=404, detail=error_msg) from e
    except ValueError as e:
        error_msg = f"Invalid value or configuration: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=400, detail=error_msg) from e
    except RuntimeError as e:
        error_msg = f"Runtime error during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=500, detail=error_msg) from e
    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}"
        # Keep the traceback in the server-side log only; don't expose internals to clients.
        processing_log_entries.append(f"{error_msg}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg) from e
    finally:
        # Print the accumulated log for any outcome other than full success
        # (early "no speech" return or any raised error).
        if not succeeded and processing_log_entries:
            print("Dialog generation failed. Log: \n" + "\n".join(processing_log_entries))


@router.post("/generate", response_model=DialogResponse)
async def generate_dialog_endpoint(
    request: DialogRequest,
    background_tasks: BackgroundTasks,
    tts_service: TTSService = Depends(get_tts_service),
    dialog_processor: DialogProcessorService = Depends(get_dialog_processor_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service),
):
    """
    Generates a dialog from a list of speech and silence items.

    - Processes text into manageable chunks.
    - Generates speech for each chunk using the specified speaker.
    - Inserts silences as requested.
    - Concatenates all audio segments into a single file.
    - Creates a ZIP archive of all individual segments and the concatenated file.
    """
    # Wrap the core processing logic with model loading/unloading so the TTS model
    # is held in memory only for the duration of this request.
    return await manage_tts_model_lifecycle(
        tts_service,
        process_dialog_flow,
        request=request,
        dialog_processor=dialog_processor,
        audio_manipulator=audio_manipulator,
        background_tasks=background_tasks,
    )