# chatterbox-ui/backend/app/routers/dialog.py
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pathlib import Path
import shutil
import os
from app.models.dialog_models import DialogRequest, DialogResponse
from app.services.tts_service import TTSService
from app.services.speaker_service import SpeakerManagementService
from app.services.dialog_processor_service import DialogProcessorService
from app.services.audio_manipulation_service import AudioManipulationService
from app import config
from typing import AsyncIterator
from app.services.model_manager import ModelManager
router = APIRouter()
# --- Dependency Injection for Services ---
# These can be more sophisticated with a proper DI container or FastAPI's Depends system if services had complex init.
# For now, direct instantiation or simple Depends is fine.
async def get_tts_service() -> AsyncIterator[TTSService]:
    """Yield the shared TTS service while holding a model-usage token.

    The token acquired via ``manager.using()`` stays held for the whole
    request, so the ModelManager will not evict the model mid-request.
    """
    model_manager = ModelManager.instance()
    async with model_manager.using():
        yield await model_manager.get_service()
def get_speaker_management_service():
    """Dependency factory: build a fresh SpeakerManagementService per request."""
    return SpeakerManagementService()
def get_dialog_processor_service(
    tts_service: TTSService = Depends(get_tts_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service)
):
    """Dependency factory: wire a DialogProcessorService with its TTS and speaker services."""
    processor = DialogProcessorService(tts_service=tts_service, speaker_service=speaker_service)
    return processor
def get_audio_manipulation_service():
    """Dependency factory: build a fresh AudioManipulationService per request."""
    return AudioManipulationService()
# --- Helper imports ---
from app.models.dialog_models import SpeechItem, SilenceItem
from app.services.tts_service import TTSService
from app.services.audio_manipulation_service import AudioManipulationService
from app.services.speaker_service import SpeakerManagementService
from fastapi import Body
import uuid
from pathlib import Path
@router.post("/generate_line")
async def generate_line(
    item: dict = Body(...),
    tts_service: TTSService = Depends(get_tts_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service),
    speaker_service: SpeakerManagementService = Depends(get_speaker_management_service)
):
    """
    Generate audio for a single dialog line (speech or silence).

    Args:
        item: Raw dict with a "type" key of either "speech" (validated as
              SpeechItem) or "silence" (validated as SilenceItem).

    Returns:
        dict: {"audio_url": "/generated_audio/<filename>"} for the generated file.

    Raises:
        HTTPException: 404 if the speaker sample is missing, 400 for an
            unknown item type, 500 on any generation failure.
    """
    try:
        item_type = item.get("type")
        if item_type == "speech":
            speech = SpeechItem(**item)
            filename_base = f"line_{uuid.uuid4().hex}"
            out_dir = Path(config.DIALOG_GENERATED_DIR)
            # Resolve the reference-audio sample for the requested speaker.
            speaker_info = speaker_service.get_speaker_by_id(speech.speaker_id)
            if not speaker_info or not getattr(speaker_info, 'sample_path', None):
                raise HTTPException(status_code=404, detail=f"Speaker sample for ID '{speech.speaker_id}' not found.")
            speaker_sample_path = speaker_info.sample_path
            # Stored sample paths may be relative; anchor them to the samples dir.
            if not os.path.isabs(speaker_sample_path):
                speaker_sample_path = str((Path(config.SPEAKER_SAMPLES_DIR) / Path(speaker_sample_path).name).resolve())
            # Generate speech (async)
            out_path = await tts_service.generate_speech(
                text=speech.text,
                speaker_sample_path=speaker_sample_path,
                output_filename_base=filename_base,
                speaker_id=speech.speaker_id,
                output_dir=out_dir,
                exaggeration=speech.exaggeration,
                cfg_weight=speech.cfg_weight,
                temperature=speech.temperature
            )
            return {"audio_url": f"/generated_audio/{out_path.name}"}
        elif item_type == "silence":
            silence = SilenceItem(**item)
            filename = f"silence_{uuid.uuid4().hex}.wav"
            out_dir = Path(config.DIALOG_GENERATED_DIR)
            out_dir.mkdir(parents=True, exist_ok=True)  # Ensure output directory exists
            out_path = out_dir / filename
            try:
                # Render a silent tensor and persist it as a WAV file.
                silence_tensor = audio_manipulator.generate_silence(silence.duration)
                import torchaudio
                torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate)
                if not out_path.exists() or out_path.stat().st_size == 0:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to generate silence. Output file not created: {out_path}"
                    )
                # BUG FIX: this previously returned the literal placeholder
                # "/generated_audio/(unknown)" instead of the actual filename.
                return {"audio_url": f"/generated_audio/{filename}"}
            except HTTPException:
                raise
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating silence: {str(e)}"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown dialog item type: {item.get('type')}. Expected 'speech' or 'silence'."
            )
    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        import traceback
        tb = traceback.format_exc()
        error_detail = f"Unexpected error: {str(e)}\n\nTraceback:\n{tb}"
        print(error_detail)  # Log to console for debugging
        raise HTTPException(
            status_code=500,
            detail=error_detail
        )
# Removed per-request load/unload in favor of ModelManager idle eviction.
async def process_dialog_flow(
    request: DialogRequest,
    dialog_processor: DialogProcessorService,
    audio_manipulator: AudioManipulationService,
    background_tasks: BackgroundTasks
) -> DialogResponse:
    """Core logic for processing the dialog request.

    Steps: generate per-item audio segments, concatenate the speech/silence
    sequence into one WAV, then bundle everything into a ZIP archive.

    Args:
        request: Validated dialog request (items + output base name).
        dialog_processor: Service that renders each dialog item to a file.
        audio_manipulator: Service for concatenation and ZIP creation.
        background_tasks: Available for scheduling cleanup (currently unused;
            temp dirs are kept for inspection).

    Returns:
        DialogResponse with log text, output URLs, and the temp segment dir.

    Raises:
        HTTPException: 404/400/500 mapped from the underlying failure type.
    """
    processing_log_entries = []
    concatenated_audio_file_path = None
    zip_archive_file_path = None
    final_temp_dir_path_str = None
    try:
        # 1. Process dialog to generate segments.
        # The DialogProcessorService creates its own temp dir for segments.
        dialog_processing_result = await dialog_processor.process_dialog(
            dialog_items=[item.model_dump() for item in request.dialog_items],
            output_base_name=request.output_base_name
        )
        processing_log_entries.append(dialog_processing_result['log'])
        segment_details = dialog_processing_result['segment_files']
        temp_segment_dir = Path(dialog_processing_result['temp_dir'])
        final_temp_dir_path_str = str(temp_segment_dir)

        # Build the concatenation plan: existing speech files plus silence
        # durations, in order. Error segments are skipped (already logged by
        # the DialogProcessor).
        items_for_concatenation = []
        for s_detail in segment_details:
            if s_detail['type'] == 'speech' and s_detail.get('path') and Path(s_detail['path']).exists():
                items_for_concatenation.append({'type': 'speech', 'path': s_detail['path']})
            elif s_detail['type'] == 'silence' and 'duration' in s_detail:
                items_for_concatenation.append({'type': 'silence', 'duration': s_detail['duration']})

        if not any(item['type'] == 'speech' for item in items_for_concatenation):
            message = "No valid speech segments were generated. Cannot create concatenated audio or ZIP."
            processing_log_entries.append(message)
            return DialogResponse(
                log="\n".join(processing_log_entries),
                temp_dir_path=final_temp_dir_path_str,
                error_message=message
            )

        # 2. Concatenate audio segments.
        config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
        concat_filename = f"{request.output_base_name}_concatenated.wav"
        concatenated_audio_file_path = config.DIALOG_GENERATED_DIR / concat_filename
        audio_manipulator.concatenate_audio_segments(
            segment_results=items_for_concatenation,
            output_concatenated_path=concatenated_audio_file_path
        )
        processing_log_entries.append(f"Concatenated audio saved to: {concatenated_audio_file_path}")

        # 3. Create ZIP archive of the individual speech segments plus the
        # concatenated file.
        zip_filename = f"{request.output_base_name}_dialog_output.zip"
        zip_archive_path = config.DIALOG_GENERATED_DIR / zip_filename
        individual_segment_paths = [
            Path(s['path']) for s in segment_details
            if s['type'] == 'speech' and s.get('path') and Path(s['path']).exists()
        ]
        audio_manipulator.create_zip_archive(
            segment_file_paths=individual_segment_paths,
            concatenated_audio_path=concatenated_audio_file_path,
            output_zip_path=zip_archive_path
        )
        # BUG FIX: keep the tracking variable in sync so the `finally` guard
        # below correctly reflects that the archive was created (it was
        # previously initialized but never assigned).
        zip_archive_file_path = zip_archive_path
        processing_log_entries.append(f"ZIP archive created at: {zip_archive_path}")

        # Temp segments are intentionally NOT auto-deleted so users can
        # inspect them; cleanup can be a separate endpoint/job. To re-enable:
        # background_tasks.add_task(shutil.rmtree, temp_segment_dir, ignore_errors=True)
        processing_log_entries.append(f"Temporary segment directory for inspection: {temp_segment_dir}")

        return DialogResponse(
            log="\n".join(processing_log_entries),
            # URLs are relative to the static serving path for generated audio.
            concatenated_audio_url=f"/generated_audio/{concat_filename}",
            zip_archive_url=f"/generated_audio/{zip_filename}",
            temp_dir_path=final_temp_dir_path_str
        )
    except FileNotFoundError as e:
        error_msg = f"File not found during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=404, detail=error_msg)
    except ValueError as e:
        error_msg = f"Invalid value or configuration: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    except RuntimeError as e:
        error_msg = f"Runtime error during dialog generation: {e}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
    except Exception as e:
        import traceback
        error_msg = f"An unexpected error occurred: {e}\n{traceback.format_exc()}"
        processing_log_entries.append(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
    finally:
        # Surface the accumulated log if we failed before producing any output.
        if not concatenated_audio_file_path and not zip_archive_file_path and processing_log_entries:
            print("Dialog generation failed. Log: \n" + "\n".join(processing_log_entries))
@router.post("/generate", response_model=DialogResponse)
async def generate_dialog_endpoint(
    request: DialogRequest,
    background_tasks: BackgroundTasks,
    tts_service: TTSService = Depends(get_tts_service),
    dialog_processor: DialogProcessorService = Depends(get_dialog_processor_service),
    audio_manipulator: AudioManipulationService = Depends(get_audio_manipulation_service)
):
    """
    Generates a dialog from a list of speech and silence items.

    - Processes text into manageable chunks.
    - Generates speech for each chunk using the specified speaker.
    - Inserts silences as requested.
    - Concatenates all audio segments into a single file.
    - Creates a ZIP archive of all individual segments and the concatenated file.
    """
    # The tts_service dependency holds the ModelManager usage token for the
    # duration of the request, keeping the model marked "in use".
    response = await process_dialog_flow(
        request=request,
        dialog_processor=dialog_processor,
        audio_manipulator=audio_manipulator,
        background_tasks=background_tasks,
    )
    return response