Patched up to work on m3 laptop. Need to fix the location specific shit.
This commit is contained in:
parent
c91a9598b1
commit
758aa02053
|
@ -15,20 +15,45 @@ This directory contains the FastAPI backend for the Chatterbox TTS application.
|
||||||
|
|
||||||
## Setup & Running
|
## Setup & Running
|
||||||
|
|
||||||
It is assumed you have a Python virtual environment at the project root (e.g., `.venv`).
|
### Prerequisites
|
||||||
|
- Python 3.8 or higher
|
||||||
|
- A Python virtual environment (recommended)
|
||||||
|
|
||||||
1. Navigate to the **project root** directory (e.g., `/Volumes/SAM2/CODE/chatterbox-test`).
|
### Installation
|
||||||
2. Activate the existing Python virtual environment:
|
|
||||||
|
1. **Navigate to the backend directory**:
|
||||||
```bash
|
```bash
|
||||||
|
cd /path/to/chatterbox-ui/backend
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Set up a virtual environment** (if not already created):
|
||||||
|
```bash
|
||||||
|
python -m venv .venv
|
||||||
source .venv/bin/activate # On macOS/Linux
|
source .venv/bin/activate # On macOS/Linux
|
||||||
# .\.venv\Scripts\activate # On Windows
|
# .\.venv\Scripts\activate # On Windows
|
||||||
```
|
```
|
||||||
3. Install dependencies (ensure your terminal is in the **project root**):
|
|
||||||
|
3. **Install dependencies**:
|
||||||
```bash
|
```bash
|
||||||
pip install -r backend/requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
4. Run the development server (ensure your terminal is in the **project root**):
|
|
||||||
|
### Running the Development Server
|
||||||
|
|
||||||
|
From the `backend` directory, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uvicorn backend.app.main:app --reload --host 0.0.0.0 --port 8000
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
The API should then be accessible at `http://127.0.0.1:8000`.
|
|
||||||
|
### Accessing the API
|
||||||
|
|
||||||
|
Once running, you can access:
|
||||||
|
- API documentation (Swagger UI): `http://127.0.0.1:8000/docs`
|
||||||
|
- Alternative API docs (ReDoc): `http://127.0.0.1:8000/redoc`
|
||||||
|
- API root: `http://127.0.0.1:8000/`
|
||||||
|
|
||||||
|
### Development Notes
|
||||||
|
- The `--reload` flag enables auto-reload on code changes
|
||||||
|
- The server will be accessible on all network interfaces with `--host 0.0.0.0`
|
||||||
|
- Default port is 8000, but you can change it with `--port <port_number>`
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Determine PROJECT_ROOT dynamically.
|
# Determine PROJECT_ROOT dynamically.
|
||||||
# If config.py is at /Volumes/SAM2/CODE/chatterbox-test/backend/app/config.py
|
# Use the current project directory instead of mounted volume paths
|
||||||
# then PROJECT_ROOT (/Volumes/SAM2/CODE/chatterbox-test) is 2 levels up.
|
PROJECT_ROOT = Path("/Users/stwhite/CODE/chatterbox-ui").resolve()
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
|
||||||
|
|
||||||
# Speaker data paths
|
# Speaker data paths
|
||||||
SPEAKER_DATA_BASE_DIR = PROJECT_ROOT / "speaker_data"
|
SPEAKER_DATA_BASE_DIR = PROJECT_ROOT / "speaker_data"
|
||||||
|
|
|
@ -12,18 +12,15 @@ app = FastAPI(
|
||||||
)
|
)
|
||||||
|
|
||||||
# CORS Middleware configuration
|
# CORS Middleware configuration
|
||||||
origins = [
|
# For development, we'll allow all origins
|
||||||
"http://localhost:8001",
|
# In production, you should restrict this to specific origins
|
||||||
"http://127.0.0.1:8001",
|
|
||||||
# Add other origins if needed, e.g., your deployed frontend URL
|
|
||||||
]
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=origins,
|
allow_origins=["*"], # Allow all origins during development
|
||||||
allow_credentials=True,
|
allow_credentials=False, # Set to False when using allow_origins=["*"]
|
||||||
allow_methods=["*"], # Allows all methods
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
allow_headers=["*"], # Allows all headers
|
allow_headers=["*"],
|
||||||
|
expose_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include routers
|
# Include routers
|
||||||
|
|
|
@ -78,22 +78,55 @@ async def generate_line(
|
||||||
temperature=speech.temperature
|
temperature=speech.temperature
|
||||||
)
|
)
|
||||||
audio_url = f"/generated_audio/{out_path.name}"
|
audio_url = f"/generated_audio/{out_path.name}"
|
||||||
|
return {"audio_url": audio_url}
|
||||||
elif item.get("type") == "silence":
|
elif item.get("type") == "silence":
|
||||||
silence = SilenceItem(**item)
|
silence = SilenceItem(**item)
|
||||||
filename = f"silence_{uuid.uuid4().hex}.wav"
|
filename = f"silence_{uuid.uuid4().hex}.wav"
|
||||||
out_path = Path(config.DIALOG_GENERATED_DIR) / filename
|
out_dir = Path(config.DIALOG_GENERATED_DIR)
|
||||||
# Generate silence tensor and save as WAV
|
out_dir.mkdir(parents=True, exist_ok=True) # Ensure output directory exists
|
||||||
silence_tensor = audio_manipulator._create_silence(silence.duration)
|
out_path = out_dir / filename
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate silence
|
||||||
|
silence_tensor = audio_manipulator.generate_silence(silence.duration)
|
||||||
import torchaudio
|
import torchaudio
|
||||||
torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate)
|
torchaudio.save(str(out_path), silence_tensor, audio_manipulator.sample_rate)
|
||||||
|
|
||||||
|
if not out_path.exists() or out_path.stat().st_size == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Failed to generate silence. Output file not created: {out_path}"
|
||||||
|
)
|
||||||
|
|
||||||
audio_url = f"/generated_audio/{filename}"
|
audio_url = f"/generated_audio/{filename}"
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=400, detail="Unknown dialog item type.")
|
|
||||||
return {"audio_url": audio_url}
|
return {"audio_url": audio_url}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
|
raise e
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Error generating silence: {str(e)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unknown dialog item type: {item.get('type')}. Expected 'speech' or 'silence'."
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException as he:
|
||||||
|
# Re-raise HTTP exceptions as-is
|
||||||
|
raise he
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
tb = traceback.format_exc()
|
tb = traceback.format_exc()
|
||||||
raise HTTPException(status_code=500, detail=f"Exception: {str(e)}\nTraceback:\n{tb}")
|
error_detail = f"Unexpected error: {str(e)}\n\nTraceback:\n{tb}"
|
||||||
|
print(error_detail) # Log to console for debugging
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=error_detail
|
||||||
|
)
|
||||||
|
|
||||||
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
||||||
"""Loads TTS model, executes task, then unloads model."""
|
"""Loads TTS model, executes task, then unloads model."""
|
||||||
|
|
|
@ -4,9 +4,41 @@ from typing import Optional
|
||||||
from chatterbox.tts import ChatterboxTTS
|
from chatterbox.tts import ChatterboxTTS
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import gc # Garbage collector for memory management
|
import gc # Garbage collector for memory management
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
# Define a directory for TTS model outputs, could be temporary or configurable
|
# Import configuration
|
||||||
TTS_OUTPUT_DIR = Path("/Volumes/SAM2/CODE/chatterbox-test/tts_outputs") # Example path
|
from app.config import TTS_TEMP_OUTPUT_DIR, SPEAKER_SAMPLES_DIR
|
||||||
|
|
||||||
|
# Use configuration for TTS output directory
|
||||||
|
TTS_OUTPUT_DIR = TTS_TEMP_OUTPUT_DIR
|
||||||
|
|
||||||
|
def safe_load_chatterbox_tts(device):
|
||||||
|
"""
|
||||||
|
Safely load ChatterboxTTS model with device mapping to handle CUDA->MPS/CPU conversion.
|
||||||
|
This patches torch.load temporarily to map CUDA tensors to the appropriate device.
|
||||||
|
"""
|
||||||
|
@contextmanager
|
||||||
|
def patch_torch_load(target_device):
|
||||||
|
original_load = torch.load
|
||||||
|
|
||||||
|
def patched_load(*args, **kwargs):
|
||||||
|
# Add map_location to handle device mapping
|
||||||
|
if 'map_location' not in kwargs:
|
||||||
|
if target_device == "mps" and torch.backends.mps.is_available():
|
||||||
|
kwargs['map_location'] = torch.device('mps')
|
||||||
|
else:
|
||||||
|
kwargs['map_location'] = torch.device('cpu')
|
||||||
|
return original_load(*args, **kwargs)
|
||||||
|
|
||||||
|
torch.load = patched_load
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.load = original_load
|
||||||
|
|
||||||
|
with patch_torch_load(device):
|
||||||
|
return ChatterboxTTS.from_pretrained(device=device)
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
def __init__(self, device: str = "mps"): # Default to MPS for Macs, can be "cpu" or "cuda"
|
def __init__(self, device: str = "mps"): # Default to MPS for Macs, can be "cpu" or "cuda"
|
||||||
|
@ -23,7 +55,7 @@ class TTSService:
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
print(f"Loading ChatterboxTTS model to device: {self.device}...")
|
print(f"Loading ChatterboxTTS model to device: {self.device}...")
|
||||||
try:
|
try:
|
||||||
self.model = ChatterboxTTS.from_pretrained(device=self.device)
|
self.model = safe_load_chatterbox_tts(self.device)
|
||||||
print("ChatterboxTTS model loaded successfully.")
|
print("ChatterboxTTS model loaded successfully.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading ChatterboxTTS model: {e}")
|
print(f"Error loading ChatterboxTTS model: {e}")
|
||||||
|
@ -105,7 +137,7 @@ if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
tts_service.load_model()
|
tts_service.load_model()
|
||||||
|
|
||||||
dummy_speaker_root = Path("/Volumes/SAM2/CODE/chatterbox-test/speaker_data/speaker_samples")
|
dummy_speaker_root = SPEAKER_SAMPLES_DIR
|
||||||
dummy_speaker_root.mkdir(parents=True, exist_ok=True)
|
dummy_speaker_root.mkdir(parents=True, exist_ok=True)
|
||||||
dummy_sample_file = dummy_speaker_root / "dummy_speaker_test.wav"
|
dummy_sample_file = dummy_speaker_root / "dummy_speaker_test.wav"
|
||||||
import os # Added for os.remove
|
import os # Added for os.remove
|
||||||
|
|
|
@ -107,18 +107,33 @@ export async function deleteSpeaker(speakerId) {
|
||||||
* @throws {Error} If the network response is not ok.
|
* @throws {Error} If the network response is not ok.
|
||||||
*/
|
*/
|
||||||
export async function generateLine(line) {
|
export async function generateLine(line) {
|
||||||
const response = await fetch(`${API_BASE_URL}/dialog/generate_line/`, {
|
console.log('generateLine called with:', line);
|
||||||
|
const response = await fetch(`${API_BASE_URL}/dialog/generate_line`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify(line),
|
body: JSON.stringify(line),
|
||||||
});
|
});
|
||||||
|
console.log('Response status:', response.status);
|
||||||
|
console.log('Response headers:', [...response.headers.entries()]);
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorData = await response.json().catch(() => ({ message: response.statusText }));
|
const errorData = await response.json().catch(() => ({ message: response.statusText }));
|
||||||
throw new Error(`Failed to generate line audio: ${errorData.detail || errorData.message || response.statusText}`);
|
throw new Error(`Failed to generate line audio: ${errorData.detail || errorData.message || response.statusText}`);
|
||||||
}
|
}
|
||||||
return response.json();
|
|
||||||
|
const responseText = await response.text();
|
||||||
|
console.log('Raw response text:', responseText);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const jsonData = JSON.parse(responseText);
|
||||||
|
console.log('Parsed JSON:', jsonData);
|
||||||
|
return jsonData;
|
||||||
|
} catch (parseError) {
|
||||||
|
console.error('JSON parse error:', parseError);
|
||||||
|
throw new Error(`Invalid JSON response: ${responseText}`);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -137,7 +152,7 @@ export async function generateLine(line) {
|
||||||
* @throws {Error} If the network response is not ok.
|
* @throws {Error} If the network response is not ok.
|
||||||
*/
|
*/
|
||||||
export async function generateDialog(dialogPayload) {
|
export async function generateDialog(dialogPayload) {
|
||||||
const response = await fetch(`${API_BASE_URL}/dialog/generate/`, {
|
const response = await fetch(`${API_BASE_URL}/dialog/generate`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|
|
@ -314,9 +314,18 @@ async function initializeDialogEditor() {
|
||||||
const payload = { ...item };
|
const payload = { ...item };
|
||||||
// Remove fields not needed by backend
|
// Remove fields not needed by backend
|
||||||
delete payload.audioUrl; delete payload.isGenerating; delete payload.error;
|
delete payload.audioUrl; delete payload.isGenerating; delete payload.error;
|
||||||
|
console.log('Sending payload:', payload);
|
||||||
const result = await generateLine(payload);
|
const result = await generateLine(payload);
|
||||||
|
console.log('Received result:', result);
|
||||||
|
if (result && result.audio_url) {
|
||||||
dialogItems[index].audioUrl = result.audio_url;
|
dialogItems[index].audioUrl = result.audio_url;
|
||||||
|
console.log('Set audioUrl to:', result.audio_url);
|
||||||
|
} else {
|
||||||
|
console.error('Invalid result structure:', result);
|
||||||
|
throw new Error('Invalid response: missing audio_url');
|
||||||
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
console.error('Error in generateLine:', err);
|
||||||
dialogItems[index].error = err.message || 'Failed to generate audio.';
|
dialogItems[index].error = err.message || 'Failed to generate audio.';
|
||||||
alert(dialogItems[index].error);
|
alert(dialogItems[index].error);
|
||||||
} finally {
|
} finally {
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
gradio>=3.50.0
|
||||||
|
PyYAML>=6.0
|
||||||
|
torch>=2.0.0
|
||||||
|
torchaudio>=2.0.0
|
||||||
|
numpy>=1.21.0
|
|
@ -0,0 +1,21 @@
|
||||||
|
831c1dbe-c379-4d9f-868b-9798adc3c05d:
|
||||||
|
name: Adam
|
||||||
|
sample_path: speaker_samples/831c1dbe-c379-4d9f-868b-9798adc3c05d.wav
|
||||||
|
608903c4-b157-46c5-a0ea-4b25eb4b83b6:
|
||||||
|
name: Denise
|
||||||
|
sample_path: speaker_samples/608903c4-b157-46c5-a0ea-4b25eb4b83b6.wav
|
||||||
|
3c93c9df-86dc-4d67-ab55-8104b9301190:
|
||||||
|
name: Maria
|
||||||
|
sample_path: speaker_samples/3c93c9df-86dc-4d67-ab55-8104b9301190.wav
|
||||||
|
fb84ce1c-f32d-4df9-9673-2c64e9603133:
|
||||||
|
name: Debbie
|
||||||
|
sample_path: speaker_samples/fb84ce1c-f32d-4df9-9673-2c64e9603133.wav
|
||||||
|
90fcd672-ba84-441a-ac6c-0449a59653bd:
|
||||||
|
name: dummy_speaker
|
||||||
|
sample_path: speaker_samples/90fcd672-ba84-441a-ac6c-0449a59653bd.wav
|
||||||
|
a6387c23-4ca4-42b5-8aaf-5699dbabbdf0:
|
||||||
|
name: Mike
|
||||||
|
sample_path: speaker_samples/a6387c23-4ca4-42b5-8aaf-5699dbabbdf0.wav
|
||||||
|
6cf4d171-667d-4bc8-adbb-6d9b7c620cb8:
|
||||||
|
name: Minnie
|
||||||
|
sample_path: speaker_samples/6cf4d171-667d-4bc8-adbb-6d9b7c620cb8.wav
|
|
@ -0,0 +1,51 @@
|
||||||
|
import torch
|
||||||
|
import torchaudio as ta
|
||||||
|
from chatterbox.tts import ChatterboxTTS
|
||||||
|
|
||||||
|
# Detect device (Mac with M1/M2/M3/M4)
|
||||||
|
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
|
||||||
|
def safe_load_chatterbox_tts(device="mps"):
|
||||||
|
"""
|
||||||
|
Safely load ChatterboxTTS model with proper device mapping.
|
||||||
|
Handles cases where model was saved on CUDA but needs to be loaded on MPS/CPU.
|
||||||
|
"""
|
||||||
|
# Store original torch.load function
|
||||||
|
original_torch_load = torch.load
|
||||||
|
|
||||||
|
def patched_torch_load(f, map_location=None, **kwargs):
|
||||||
|
# If no map_location is specified and we're loading on non-CUDA device,
|
||||||
|
# map CUDA tensors to the target device
|
||||||
|
if map_location is None:
|
||||||
|
if device == "mps" and torch.backends.mps.is_available():
|
||||||
|
map_location = torch.device("mps")
|
||||||
|
elif device == "cpu" or not torch.cuda.is_available():
|
||||||
|
map_location = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
map_location = torch.device(device)
|
||||||
|
|
||||||
|
return original_torch_load(f, map_location=map_location, **kwargs)
|
||||||
|
|
||||||
|
# Temporarily patch torch.load
|
||||||
|
torch.load = patched_torch_load
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load the model with the patched torch.load
|
||||||
|
model = ChatterboxTTS.from_pretrained(device=device)
|
||||||
|
return model
|
||||||
|
finally:
|
||||||
|
# Restore original torch.load
|
||||||
|
torch.load = original_torch_load
|
||||||
|
|
||||||
|
model = safe_load_chatterbox_tts(device=device)
|
||||||
|
text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day."
|
||||||
|
|
||||||
|
# If you want to synthesize with a different voice, specify the audio prompt
|
||||||
|
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
|
||||||
|
wav = model.generate(
|
||||||
|
text,
|
||||||
|
audio_prompt_path=AUDIO_PROMPT_PATH,
|
||||||
|
exaggeration=2.0,
|
||||||
|
cfg_weight=0.5
|
||||||
|
)
|
||||||
|
ta.save("test-2.wav", wav, model.sr)
|
Loading…
Reference in New Issue