chatterbox-ui/backend/app/routers/dialog.py

190 lines
8.7 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pathlib import Path
import shutil
from app.models.dialog_models import DialogRequest, DialogResponse
from app.services.tts_service import TTSService
from app.services.speaker_service import SpeakerManagementService
from app.services.dialog_processor_service import DialogProcessorService
from app.services.audio_manipulation_service import AudioManipulationService
from app import config
# Router holding the dialog-generation endpoints defined below. Presumably it is
# included into the FastAPI app elsewhere; the mount prefix is not visible here.
router = APIRouter()
# --- Dependency Injection for Services ---
# The providers below are plain factory functions wired together with FastAPI's
# ``Depends``; each call constructs a new service instance. A dedicated DI container
# could replace them if service initialisation becomes more complex.
def get_tts_service():
    """FastAPI dependency: build a TTSService for the current request.

    The compute device comes from ``config.TTS_DEVICE`` when the config module
    defines it; otherwise it falls back to ``"mps"``, the previously hard-coded value,
    so existing deployments behave exactly as before.

    The device is deliberately not a function parameter: FastAPI would expose such a
    parameter as a request query parameter.
    """
    # getattr keeps this working with config modules that do not define TTS_DEVICE.
    device = getattr(config, "TTS_DEVICE", "mps")
    return TTSService(device=device)
def get_speaker_management_service():
    """FastAPI dependency: provide a freshly constructed SpeakerManagementService."""
    speaker_service = SpeakerManagementService()
    return speaker_service
def get_dialog_processor_service(
    tts_service: TTSService = Depends(get_tts_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service)
):
    """FastAPI dependency: assemble a DialogProcessorService from injected services.

    FastAPI caches a dependency's result within a single request. The ``tts_service``
    injected here is therefore the same instance that ``/generate`` receives and loads
    the model on before running the dialog flow.
    """
    processor = DialogProcessorService(
        speaker_service=speaker_service,
        tts_service=tts_service,
    )
    return processor
def get_audio_manipulation_service():
    """FastAPI dependency: provide a freshly constructed AudioManipulationService."""
    audio_service = AudioManipulationService()
    return audio_service
# --- Helper function to manage TTS model loading/unloading ---
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
    """Run ``task_function`` with the TTS model loaded, always unloading it afterwards.

    Args:
        tts_service: Service whose ``load_model``/``unload_model`` bracket the task.
        task_function: Coroutine function to await while the model is loaded.
        *args, **kwargs: Forwarded unchanged to ``task_function``.

    Returns:
        Whatever ``task_function`` returns.

    Raises:
        Any exception from loading the model or from the task. It is printed and then
        re-raised unchanged; the model is unloaded in every case.
    """
    try:
        print("API: Loading TTS model...")
        tts_service.load_model()
        outcome = await task_function(*args, **kwargs)
    except Exception as exc:
        print(f"API: Error during TTS model lifecycle or task execution: {exc}")
        raise
    else:
        return outcome
    finally:
        # Runs on success, on failure, and even if load_model itself failed.
        print("API: Unloading TTS model...")
        tts_service.unload_model()
def _output_path_in_generated_dir(filename: str) -> Path:
    """Return ``config.DIALOG_GENERATED_DIR / filename``, refusing paths that escape it.

    ``filename`` is built from the client-supplied ``output_base_name``. Without this
    check, a value such as ``"../../x"`` or an absolute path would cause the
    concatenated WAV or the ZIP to be written outside the generated-audio directory.

    Raises:
        ValueError: if the resolved path does not lie inside ``DIALOG_GENERATED_DIR``.
    """
    base_dir = config.DIALOG_GENERATED_DIR
    candidate = base_dir / filename
    # resolve() is non-strict, so this works before the directory or file exists.
    if base_dir.resolve() not in candidate.resolve().parents:
        raise ValueError(
            f"output_base_name must not point outside the output directory: {filename!r}"
        )
    return candidate


async def process_dialog_flow(
    request: DialogRequest,
    dialog_processor: DialogProcessorService,
    audio_manipulator: AudioManipulationService,
    background_tasks: BackgroundTasks
) -> DialogResponse:
    """Core logic for processing the dialog request.

    Steps:
      1. ``dialog_processor.process_dialog`` generates the per-item segments into its
         own temporary directory.
      2. Existing speech segments and requested silences are concatenated into
         ``<output_base_name>_concatenated.wav`` in ``config.DIALOG_GENERATED_DIR``.
      3. The speech segments plus the concatenated file are zipped into
         ``<output_base_name>_dialog_output.zip`` in the same directory.

    Args:
        request: The dialog request; ``output_base_name`` prefixes the output files.
        dialog_processor: Service that generates the segment files.
        audio_manipulator: Service that concatenates audio and builds the ZIP.
        background_tasks: Currently unused; kept so temp-dir cleanup can be
            scheduled later without changing the signature.

    Returns:
        A ``DialogResponse`` carrying the accumulated log, the temp segment directory
        and the URLs of the generated files. If no speech segment could be produced,
        it instead carries ``error_message`` and no URLs.

    Raises:
        HTTPException: 404 for ``FileNotFoundError``; 400 for ``ValueError``
            (including an ``output_base_name`` that would escape the output
            directory); 500 for ``RuntimeError`` and any other exception.
    """
    processing_log_entries = []
    final_temp_dir_path_str = None
    # Set only immediately before the successful return. ``finally`` uses it to decide
    # whether to dump the log as a failure. (Inferring failure from path variables
    # missed every error raised after the output path had been computed.)
    succeeded = False
    try:
        concat_filename = f"{request.output_base_name}_concatenated.wav"
        zip_filename = f"{request.output_base_name}_dialog_output.zip"
        # Validate the client-controlled names before doing any expensive work.
        concatenated_audio_file_path = _output_path_in_generated_dir(concat_filename)
        zip_archive_path = _output_path_in_generated_dir(zip_filename)

        # 1. Process dialog to generate segments.
        # The DialogProcessorService creates its own temp dir for segments.
        dialog_processing_result = await dialog_processor.process_dialog(
            dialog_items=[item.model_dump() for item in request.dialog_items],
            output_base_name=request.output_base_name
        )
        processing_log_entries.append(dialog_processing_result['log'])
        segment_details = dialog_processing_result['segment_files']
        temp_segment_dir = Path(dialog_processing_result['temp_dir'])
        final_temp_dir_path_str = str(temp_segment_dir)

        # One pass builds both the concatenation plan (speech paths + silence
        # durations) and the list of speech files to zip. Speech entries without an
        # existing file are skipped; errors are already logged by DialogProcessor.
        items_for_concatenation = []
        individual_segment_paths = []
        for s_detail in segment_details:
            if s_detail['type'] == 'speech' and s_detail.get('path') and Path(s_detail['path']).exists():
                items_for_concatenation.append({'type': 'speech', 'path': s_detail['path']})
                individual_segment_paths.append(Path(s_detail['path']))
            elif s_detail['type'] == 'silence' and 'duration' in s_detail:
                items_for_concatenation.append({'type': 'silence', 'duration': s_detail['duration']})

        if not individual_segment_paths:
            message = "No valid speech segments were generated. Cannot create concatenated audio or ZIP."
            processing_log_entries.append(message)
            return DialogResponse(
                log="\n".join(processing_log_entries),
                temp_dir_path=final_temp_dir_path_str,
                error_message=message
            )

        # 2. Concatenate audio segments.
        config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
        audio_manipulator.concatenate_audio_segments(
            segment_results=items_for_concatenation,
            output_concatenated_path=concatenated_audio_file_path
        )
        processing_log_entries.append(f"Concatenated audio saved to: {concatenated_audio_file_path}")

        # 3. Create ZIP archive of the individual speech segments plus the concatenated file.
        audio_manipulator.create_zip_archive(
            segment_file_paths=individual_segment_paths,
            concatenated_audio_path=concatenated_audio_file_path,
            output_zip_path=zip_archive_path
        )
        processing_log_entries.append(f"ZIP archive created at: {zip_archive_path}")

        # NOTE: the temporary segment directory is intentionally kept so the user can
        # inspect it. Cleanup could be scheduled with
        # ``background_tasks.add_task(shutil.rmtree, temp_segment_dir, ignore_errors=True)``
        # or handled by a separate endpoint/job.
        processing_log_entries.append(f"Temporary segment directory for inspection: {temp_segment_dir}")

        response = DialogResponse(
            log="\n".join(processing_log_entries),
            # URLs are relative to the static mount that serves DIALOG_GENERATED_DIR
            # (assumed to be /generated_audio/ — confirm against the app setup).
            concatenated_audio_url=f"/generated_audio/{concat_filename}",
            zip_archive_url=f"/generated_audio/{zip_filename}",
            temp_dir_path=final_temp_dir_path_str
        )
        succeeded = True
        return response
    except FileNotFoundError as e:
        error_msg = f"File not found during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=404, detail=error_msg) from e
    except ValueError as e:
        error_msg = f"Invalid value or configuration: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=400, detail=error_msg) from e
    except RuntimeError as e:
        error_msg = f"Runtime error during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        # This could be a 500 if it's an unexpected server error
        raise HTTPException(status_code=500, detail=error_msg) from e
    except Exception as e:
        import traceback
        # NOTE(review): the full traceback is returned to the client; acceptable for
        # a local tool, but consider logging it server-side only if exposed publicly.
        error_msg = f"An unexpected error occurred: {e}\n{traceback.format_exc()}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=500, detail=error_msg) from e
    finally:
        # Dump the log for every unsuccessful outcome: raised errors and the
        # "no valid speech segments" early return.
        if not succeeded and processing_log_entries:
            print("Dialog generation failed. Log: \n" + "\n".join(processing_log_entries))
@router.post("/generate", response_model=DialogResponse)
async def generate_dialog_endpoint(
    request: DialogRequest,
    background_tasks: BackgroundTasks,
    tts_service: TTSService = Depends(get_tts_service),
    dialog_processor: DialogProcessorService = Depends(get_dialog_processor_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service)
):
    """
    Generates a dialog from a list of speech and silence items.
    - Processes text into manageable chunks.
    - Generates speech for each chunk using the specified speaker.
    - Inserts silences as requested.
    - Concatenates all audio segments into a single file.
    - Creates a ZIP archive of all individual segments and the concatenated file.
    """
    # The docstring above doubles as the OpenAPI description, so its text is kept as-is.
    # Arguments forwarded verbatim to the core dialog flow.
    flow_arguments = {
        "request": request,
        "dialog_processor": dialog_processor,
        "audio_manipulator": audio_manipulator,
        "background_tasks": background_tasks,
    }
    # The TTS model is loaded before the flow runs and unloaded afterwards, even on error.
    response = await manage_tts_model_lifecycle(tts_service, process_dialog_flow, **flow_arguments)
    return response