From 758aa020530a1d0dfe3147eb66c4b52773c2001e Mon Sep 17 00:00:00 2001
From: Steve White
Date: Sat, 7 Jun 2025 16:06:38 -0500
Subject: [PATCH] Patched up to work on M3 laptop. Need to fix the
 location-specific paths.

---
 backend/README.md                       | 57 ++++++++++++++++++-------
 backend/app/config.py                   |  5 +--
 backend/app/main.py                     | 17 +++----
 backend/app/routers/dialog.py           | 51 ++++++++++++++++++----
 backend/app/services/tts_service.py     | 40 +++++++++++++++--
 frontend/js/api.js                      | 21 +++++++--
 frontend/js/app.js                      | 11 ++++-
 requirements.txt                        |  5 +++
 speaker_data/speaker_data/speakers.yaml | 21 +++++++++
 test.py                                 | 51 ++++++++++++++++++++++
 10 files changed, 233 insertions(+), 46 deletions(-)
 create mode 100644 requirements.txt
 create mode 100644 speaker_data/speaker_data/speakers.yaml
 create mode 100644 test.py

diff --git a/backend/README.md b/backend/README.md
index 03ab194..ab87caf 100644
--- a/backend/README.md
+++ b/backend/README.md
@@ -15,20 +15,45 @@ This directory contains the FastAPI backend for the Chatterbox TTS application.
 
 ## Setup & Running
 
-It is assumed you have a Python virtual environment at the project root (e.g., `.venv`).
+### Prerequisites
+- Python 3.8 or higher
+- A Python virtual environment (recommended)
 
-1. Navigate to the **project root** directory (e.g., `/Volumes/SAM2/CODE/chatterbox-test`).
-2. Activate the existing Python virtual environment:
-   ```bash
-   source .venv/bin/activate  # On macOS/Linux
-   # .\.venv\Scripts\activate  # On Windows
-   ```
-3. Install dependencies (ensure your terminal is in the **project root**):
-   ```bash
-   pip install -r backend/requirements.txt
-   ```
-4. Run the development server (ensure your terminal is in the **project root**):
-   ```bash
-   uvicorn backend.app.main:app --reload --host 0.0.0.0 --port 8000
-   ```
-The API should then be accessible at `http://127.0.0.1:8000`.
+### Installation
+
+1. **Navigate to the backend directory**:
+   ```bash
+   cd /path/to/chatterbox-ui/backend
+   ```
+
+2. **Set up a virtual environment** (if not already created):
+   ```bash
+   python -m venv .venv
+   source .venv/bin/activate  # On macOS/Linux
+   # .\.venv\Scripts\activate  # On Windows
+   ```
+
+3. **Install dependencies**:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+### Running the Development Server
+
+From the `backend` directory, run:
+
+```bash
+uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
+```
+
+### Accessing the API
+
+Once running, you can access:
+- API documentation (Swagger UI): `http://127.0.0.1:8000/docs`
+- Alternative API docs (ReDoc): `http://127.0.0.1:8000/redoc`
+- API root: `http://127.0.0.1:8000/`
+
+### Development Notes
+- The `--reload` flag enables auto-reload on code changes
+- The server will be accessible on all network interfaces with `--host 0.0.0.0`
+- Default port is 8000, but you can change it with `--port <port>`
diff --git a/backend/app/config.py b/backend/app/config.py
index 5ba1601..a9c2f9a 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -1,9 +1,8 @@
 from pathlib import Path
 
 # Determine PROJECT_ROOT dynamically.
-# If config.py is at /Volumes/SAM2/CODE/chatterbox-test/backend/app/config.py
-# then PROJECT_ROOT (/Volumes/SAM2/CODE/chatterbox-test) is 2 levels up.
-PROJECT_ROOT = Path(__file__).resolve().parents[2] +# Use the current project directory instead of mounted volume paths +PROJECT_ROOT = Path("/Users/stwhite/CODE/chatterbox-ui").resolve() # Speaker data paths SPEAKER_DATA_BASE_DIR = PROJECT_ROOT / "speaker_data" diff --git a/backend/app/main.py b/backend/app/main.py index 2d7849b..788428a 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -12,18 +12,15 @@ app = FastAPI( ) # CORS Middleware configuration -origins = [ - "http://localhost:8001", - "http://127.0.0.1:8001", - # Add other origins if needed, e.g., your deployed frontend URL -] - +# For development, we'll allow all origins +# In production, you should restrict this to specific origins app.add_middleware( CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers + allow_origins=["*"], # Allow all origins during development + allow_credentials=False, # Set to False when using allow_origins=["*"] + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], + expose_headers=["*"], ) # Include routers diff --git a/backend/app/routers/dialog.py b/backend/app/routers/dialog.py index 88a44c4..adb0508 100644 --- a/backend/app/routers/dialog.py +++ b/backend/app/routers/dialog.py @@ -78,22 +78,55 @@ async def generate_line( temperature=speech.temperature ) audio_url = f"/generated_audio/{out_path.name}" + return {"audio_url": audio_url} elif item.get("type") == "silence": silence = SilenceItem(**item) filename = f"silence_{uuid.uuid4().hex}.wav" - out_path = Path(config.DIALOG_GENERATED_DIR) / filename - # Generate silence tensor and save as WAV - silence_tensor = audio_manipulator._create_silence(silence.duration) - import torchaudio - torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate) - audio_url = f"/generated_audio/{filename}" + out_dir = Path(config.DIALOG_GENERATED_DIR) + out_dir.mkdir(parents=True, exist_ok=True) # Ensure output directory exists + out_path = out_dir / filename + + try: + # Generate silence + silence_tensor = audio_manipulator.generate_silence(silence.duration) + import torchaudio + torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate) + + if not out_path.exists() or out_path.stat().st_size == 0: + raise HTTPException( + status_code=500, + detail=f"Failed to generate silence. Output file not created: {out_path}" + ) + + audio_url = f"/generated_audio/{filename}" + return {"audio_url": audio_url} + + except Exception as e: + if isinstance(e, HTTPException): + raise e + raise HTTPException( + status_code=500, + detail=f"Error generating silence: {str(e)}" + ) else: - raise HTTPException(status_code=400, detail="Unknown dialog item type.") - return {"audio_url": audio_url} + raise HTTPException( + status_code=400, + detail=f"Unknown dialog item type: {item.get('type')}. Expected 'speech' or 'silence'." 
+ ) + + except HTTPException as he: + # Re-raise HTTP exceptions as-is + raise he + except Exception as e: import traceback tb = traceback.format_exc() - raise HTTPException(status_code=500, detail=f"Exception: {str(e)}\nTraceback:\n{tb}") + error_detail = f"Unexpected error: {str(e)}\n\nTraceback:\n{tb}" + print(error_detail) # Log to console for debugging + raise HTTPException( + status_code=500, + detail=error_detail + ) async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs): """Loads TTS model, executes task, then unloads model.""" diff --git a/backend/app/services/tts_service.py b/backend/app/services/tts_service.py index 266dd1c..2a34469 100644 --- a/backend/app/services/tts_service.py +++ b/backend/app/services/tts_service.py @@ -4,9 +4,41 @@ from typing import Optional from chatterbox.tts import ChatterboxTTS from pathlib import Path import gc # Garbage collector for memory management +import os +from contextlib import contextmanager -# Define a directory for TTS model outputs, could be temporary or configurable -TTS_OUTPUT_DIR = Path("/Volumes/SAM2/CODE/chatterbox-test/tts_outputs") # Example path +# Import configuration +from app.config import TTS_TEMP_OUTPUT_DIR, SPEAKER_SAMPLES_DIR + +# Use configuration for TTS output directory +TTS_OUTPUT_DIR = TTS_TEMP_OUTPUT_DIR + +def safe_load_chatterbox_tts(device): + """ + Safely load ChatterboxTTS model with device mapping to handle CUDA->MPS/CPU conversion. + This patches torch.load temporarily to map CUDA tensors to the appropriate device. + """ + @contextmanager + def patch_torch_load(target_device): + original_load = torch.load + + def patched_load(*args, **kwargs): + # Add map_location to handle device mapping + if 'map_location' not in kwargs: + if target_device == "mps" and torch.backends.mps.is_available(): + kwargs['map_location'] = torch.device('mps') + else: + kwargs['map_location'] = torch.device('cpu') + return original_load(*args, **kwargs) + + torch.load = patched_load + try: + yield + finally: + torch.load = original_load + + with patch_torch_load(device): + return ChatterboxTTS.from_pretrained(device=device) class TTSService: def __init__(self, device: str = "mps"): # Default to MPS for Macs, can be "cpu" or "cuda" @@ -23,7 +55,7 @@ class TTSService: if self.model is None: print(f"Loading ChatterboxTTS model to device: {self.device}...") try: - self.model = ChatterboxTTS.from_pretrained(device=self.device) + self.model = safe_load_chatterbox_tts(self.device) print("ChatterboxTTS model loaded successfully.") except Exception as e: print(f"Error loading ChatterboxTTS model: {e}") @@ -105,7 +137,7 @@ if __name__ == "__main__": try: tts_service.load_model() - dummy_speaker_root = Path("/Volumes/SAM2/CODE/chatterbox-test/speaker_data/speaker_samples") + dummy_speaker_root = SPEAKER_SAMPLES_DIR dummy_speaker_root.mkdir(parents=True, exist_ok=True) dummy_sample_file = dummy_speaker_root / "dummy_speaker_test.wav" import os # Added for os.remove diff --git a/frontend/js/api.js b/frontend/js/api.js index 611f5ee..efeeaba 100644 --- a/frontend/js/api.js +++ b/frontend/js/api.js @@ -107,18 +107,33 @@ export async function deleteSpeaker(speakerId) { * @throws {Error} If the network response is not ok. 
*/ export async function generateLine(line) { - const response = await fetch(`${API_BASE_URL}/dialog/generate_line/`, { + console.log('generateLine called with:', line); + const response = await fetch(`${API_BASE_URL}/dialog/generate_line`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify(line), }); + console.log('Response status:', response.status); + console.log('Response headers:', [...response.headers.entries()]); + if (!response.ok) { const errorData = await response.json().catch(() => ({ message: response.statusText })); throw new Error(`Failed to generate line audio: ${errorData.detail || errorData.message || response.statusText}`); } - return response.json(); + + const responseText = await response.text(); + console.log('Raw response text:', responseText); + + try { + const jsonData = JSON.parse(responseText); + console.log('Parsed JSON:', jsonData); + return jsonData; + } catch (parseError) { + console.error('JSON parse error:', parseError); + throw new Error(`Invalid JSON response: ${responseText}`); + } } /** @@ -137,7 +152,7 @@ export async function generateLine(line) { * @throws {Error} If the network response is not ok. */ export async function generateDialog(dialogPayload) { - const response = await fetch(`${API_BASE_URL}/dialog/generate/`, { + const response = await fetch(`${API_BASE_URL}/dialog/generate`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/frontend/js/app.js b/frontend/js/app.js index 39ee24a..7210552 100644 --- a/frontend/js/app.js +++ b/frontend/js/app.js @@ -314,9 +314,18 @@ async function initializeDialogEditor() { const payload = { ...item }; // Remove fields not needed by backend delete payload.audioUrl; delete payload.isGenerating; delete payload.error; + console.log('Sending payload:', payload); const result = await generateLine(payload); - dialogItems[index].audioUrl = result.audio_url; + console.log('Received result:', result); + if (result && result.audio_url) { + dialogItems[index].audioUrl = result.audio_url; + console.log('Set audioUrl to:', result.audio_url); + } else { + console.error('Invalid result structure:', result); + throw new Error('Invalid response: missing audio_url'); + } } catch (err) { + console.error('Error in generateLine:', err); dialogItems[index].error = err.message || 'Failed to generate audio.'; alert(dialogItems[index].error); } finally { diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fda3a2e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +gradio>=3.50.0 +PyYAML>=6.0 +torch>=2.0.0 +torchaudio>=2.0.0 +numpy>=1.21.0 diff --git a/speaker_data/speaker_data/speakers.yaml b/speaker_data/speaker_data/speakers.yaml new file mode 100644 index 0000000..3331196 --- /dev/null +++ b/speaker_data/speaker_data/speakers.yaml @@ -0,0 +1,21 @@ +831c1dbe-c379-4d9f-868b-9798adc3c05d: + name: Adam + sample_path: speaker_samples/831c1dbe-c379-4d9f-868b-9798adc3c05d.wav +608903c4-b157-46c5-a0ea-4b25eb4b83b6: + name: Denise + sample_path: speaker_samples/608903c4-b157-46c5-a0ea-4b25eb4b83b6.wav +3c93c9df-86dc-4d67-ab55-8104b9301190: + name: Maria + sample_path: speaker_samples/3c93c9df-86dc-4d67-ab55-8104b9301190.wav +fb84ce1c-f32d-4df9-9673-2c64e9603133: + name: Debbie + sample_path: speaker_samples/fb84ce1c-f32d-4df9-9673-2c64e9603133.wav +90fcd672-ba84-441a-ac6c-0449a59653bd: + name: dummy_speaker + sample_path: speaker_samples/90fcd672-ba84-441a-ac6c-0449a59653bd.wav +a6387c23-4ca4-42b5-8aaf-5699dbabbdf0: + name: Mike + 
sample_path: speaker_samples/a6387c23-4ca4-42b5-8aaf-5699dbabbdf0.wav +6cf4d171-667d-4bc8-adbb-6d9b7c620cb8: + name: Minnie + sample_path: speaker_samples/6cf4d171-667d-4bc8-adbb-6d9b7c620cb8.wav diff --git a/test.py b/test.py new file mode 100644 index 0000000..033b6f5 --- /dev/null +++ b/test.py @@ -0,0 +1,51 @@ +import torch +import torchaudio as ta +from chatterbox.tts import ChatterboxTTS + +# Detect device (Mac with M1/M2/M3/M4) +device = "mps" if torch.backends.mps.is_available() else "cpu" + +def safe_load_chatterbox_tts(device="mps"): + """ + Safely load ChatterboxTTS model with proper device mapping. + Handles cases where model was saved on CUDA but needs to be loaded on MPS/CPU. + """ + # Store original torch.load function + original_torch_load = torch.load + + def patched_torch_load(f, map_location=None, **kwargs): + # If no map_location is specified and we're loading on non-CUDA device, + # map CUDA tensors to the target device + if map_location is None: + if device == "mps" and torch.backends.mps.is_available(): + map_location = torch.device("mps") + elif device == "cpu" or not torch.cuda.is_available(): + map_location = torch.device("cpu") + else: + map_location = torch.device(device) + + return original_torch_load(f, map_location=map_location, **kwargs) + + # Temporarily patch torch.load + torch.load = patched_torch_load + + try: + # Load the model with the patched torch.load + model = ChatterboxTTS.from_pretrained(device=device) + return model + finally: + # Restore original torch.load + torch.load = original_torch_load + +model = safe_load_chatterbox_tts(device=device) +text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day." + +# If you want to synthesize with a different voice, specify the audio prompt +AUDIO_PROMPT_PATH = "YOUR_FILE.wav" +wav = model.generate( + text, + audio_prompt_path=AUDIO_PROMPT_PATH, + exaggeration=2.0, + cfg_weight=0.5 + ) +ta.save("test-2.wav", wav, model.sr)
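
A note on the device-mapping workaround that appears in both `backend/app/services/tts_service.py` and `test.py` above: checkpoints saved on a CUDA machine will not deserialize on Apple Silicon unless their tensors are retargeted at load time. The monkeypatch is only needed because the `torch.load` calls happen inside `ChatterboxTTS.from_pretrained()`, out of the caller's reach; when you control the load yourself, `map_location` does the same job directly. A minimal sketch, assuming only `torch` is installed and using a hypothetical `checkpoint.pt` file saved on a CUDA machine:

```python
import torch

# Prefer Apple's Metal backend when available, otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# map_location retargets CUDA-saved tensors during deserialization; this is
# exactly the argument the patched torch.load in tts_service.py injects into
# loads it cannot reach directly.
state = torch.load("checkpoint.pt", map_location=torch.device(device))
```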